从代码角度理解DETR

2023-11-18

  1. 一个cnn的backbone, 提图像的feature, 比如, HWC.
  2. 同时对这个feature做position_embedding.
  3. 然后二者相加 (在Transformer里面就是二者相加)
  4. 输入encoder,
  5. 输入decoder (这里有object queries.)
  6. 然后接Prediction Heads, 比如分类和回归.

下面的代码参考自: https://github.com/facebookresearch/detr
commit-id: 3af9fa8
在这里插入图片描述
可以看到, 这里传入的有backbone, transformer, 输入的类别个数(用来确定head的输出维度), num_quries, 以及是否需要aux_loss等.

先看一下forward的解释
在这里插入图片描述
输入是一堆图片和对应的mask. 这里mask先不管. 后面再来细看其具体起的作用;
输出是logits, boxes, 还有aux_outputs(只有在用aux_loss的时候才会有这个的输出)

接下来看第一步
在这里插入图片描述

backbone部分

从这里在这里插入图片描述
可以看出来, backbone是这两个的结合.
在这里插入图片描述
也就是说最终backbone输出的第一个其实是图像backbone提的feature, 第二个是每个feature所对应的 position encoding.

Transormer部分

transformer的输入如下

在这里插入图片描述
由于没有跑代码. 这里
src: 先理解成是images的feature, pos[-1], 先理解成是position-encoding

这里input_proj. 是对输入的src做了一个FC.
在这里插入图片描述
这里的query_embed是

在这里插入图片描述
可以理解成是一个query的词典表.

之后如图
在这里插入图片描述
encoder部分和原始transformer是一致的, 而decoder部分, 原始transformer输入的是trg_seq. 而这里是一个全0的矩阵. 大小与query_embed一样大.

之前transformer的decoder中是trg_seq 与 src_seq的encoder的output 做encoder-decoder-attention. 但是现在detr里的decoder中, 到现在还没有用到groundtruth.

head部分

在这里插入图片描述
这里bbox的coord取sigmoid的原因是gt是按图像的长宽给的比例.
在这里插入图片描述

aux_loss

loss部分

bbox-loss. 采用的有 l1-loss 以及giou loss, 这里用giou-loss的原因是scale-不变.

在这里插入图片描述

SetCriterion

在这里插入图片描述

这里写得很清晰, 即对gt和dt做一次Hungarian匹配, 然后, 将匹配到的pair, 去算loss, loss包括类别和bbox.
我理解这个过程相当于label-assign, 只不过特殊的地方是这个label-assign 是一个一对一的, 这是和之前一些一阶段和二阶段检测不太一样的地方. 比如使用anchor的方法中, 可能多个anchor会对应到一个gt上面, 这也是为啥那些方法的后处理中要使用NMS, 相当于是一种搜索排查式的检测方式,先检测出一堆proposals, 再选出置信度较高的.

loss_cardinality

这其实不是一个loss, 就是为了统计预测的object的数量与ground-truth数量之间的差异. 用来观测.

HungarianMatcher

这个就是一个匈牙利匹配, 用的 from scipy.optimize import linear_sum_assignment, 这里用pytorch的方式封装了一下.

mask起的作用

二维positionEncoding的细节

paper-reading

Abstract

  1. 把目标检测视为一个集合预测问题. 从设计上去掉了很多的人为操作,比如anchor设定, nms 等.
  2. 更关注object与image context 之间的本质, 直接去预测最终的结果集合. 而非"搜索式检测"
  3. 不需要开发额外的库,比如roi-align, roi-pooling, 这些操作…
  4. 很容易换一个head就可以去做分割的任务,

pipline


整个Pipline看上去很好理解, 细节主要体现在 图像的backbone的features如何转化成为 word-embeding似的输入, 进入到transformer中.

在大目标上面要比小目标上好

在这里插入图片描述
这里解释说在大目标上效果比较好是因为transormer的non-local的机制, 这一点我的理解是, transformer由于内部的self-attention操作, 使得输入的一句话中每一个词彼此之间都会去算attention的加权分数。 所以哪怕是某一个词的预测,它也是依赖于整个句子的. 所以是一个non-local的操作.
而小目标, 因为占据的图像中的位置比较少. 别的位置对于这个小目标的attention不那么重要, 因此这种non-local的操作,对于小目标不太友好.
当然作者也提了可以用其他的方式来缓解小目标不好的问题, 比如FPN.

训练时间长

在这里插入图片描述

对于set prediction问题, 两个重要的部件

set prediction loss

这个主要用于在predictin和ground-truth之间建立one-one map.
DETR是预测N个objects, N是一个超参, 比如100.

能够预测objects以及他们之间关系的模型结构. 这里就是指的是DETR

backbone

在这里插入图片描述
比较好理解,就是正常的 2d-backbone.

encoder

在这里插入图片描述
关键的部分也都在上面标出来了.

decoder

在这里插入图片描述

说实话,没太理解 N个, object-query 到最后为啥能够预测 N个 final predictions. 背后的原因是啥?

prediction Heads

这个就是正常的heads.

auxiliary losses

为了帮助模型训练, 在每个decoder层后面, 加了PredictionHeads 和Hungarian loss 来做监督, 并且这些 层是share 参数的.
在这里插入图片描述

实验对比

encoder层的影响

从下表可以看出来, 用encoder还是很有用的.
在这里插入图片描述

而且有可视化结果证明, encoder 层似乎已经把目标分离开来了.
在这里插入图片描述

decoder层数的影响

在这里插入图片描述

  1. decoder层数的增加, 效果变好,
  2. 当decoder层数只有一层的时候, NMS有用.
  3. 当decoder层数大于一层的时候, NMS几乎没有什么用.
    这说明 单个decoder层不足以表现不同输出间的关系.
    在这里插入图片描述

FFN层的重要性

去掉之后会掉点.
在这里插入图片描述

PositionEncoding 的重要性

  1. spatioal position encoding 在encoder 和decoder 中都非常重要. 没有的话会掉6个点左右.
    在这里插入图片描述

object_query

从代码里看, 是这样的流程.

在这里插入图片描述
比如 num_queries 是100, 而 hidden_dim 是64的话.

那么query_embed.weight 也是 [100, 64] 维的.

这里进去encoder的, tgt 每次其实都是0. 而query_embed.weight 充当的是query_pos.

我理解这个就是上面说的, output position encodings.

因此当normalize_before=False的时候, decoder的时候会直接走forward_post,
从下面的代码可以看出
在这里插入图片描述

在decoder的时候, 最初的一层的输入, 因为tgt都是0, 所以 q=k=query_pos.

所以这里的object_query 其实就是随机初始化的 query_embed.weight.

这里解释下nn.embedding 是什么意思.

nn.embedding可以理解为是一个词嵌入模块. 它是有weight的. 这是一个可以学习的层, 有参数,类似于conv2d.

比如上面的例子中, query_embed 就可以理解为是一个词典, 只不过这个词典有点小, 只有100个词, 每个词的embedding的大小是64.
forward的时候,可以传入indices, 来得到对应的每个单词的embeddings. 可以传入batch的indices.

推理的时候query_embed充当什么角色?

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

从代码角度理解DETR 的相关文章

随机推荐

  • LVGL V9.0基于VS2022仿真搭建(2)

    完整Demo lvgl lvgl drivers相关资料下载 链接 https pan baidu com s 1h3OKCIBQRX0Hn7KjZsynqg 提取码 sc2l 下载的lv drivers中的win32drv c及win32
  • 云链商城连锁门店新零售O20系统以零售商城

    云链商城连锁门店新零售O20系统以零售商城 门店收银 多渠道进销存 客户管理 互动营销 导购助手 多种奖励模式和数据分析等功能 赋能多品牌连锁门店实现线上线下商品 会员 场景的互联互通 助推企业快速实现营销 服务 效率转型升级 为实体零售企
  • Idea中Java项目修改项目名

    1 修改项目文件夹名称 下面是在Idea中改 也可以直接找到项目文件夹重命名 2 关闭项目 Idea会默认生成原项目名的文件夹 将其删除 3 导入重命名后的项目文件夹 4 导入成功后 在Idea中修改模块名称 大功告成 修改项目名总共有三处
  • 【Java】用do-while循环,实现猜数字。

    package TcmStudy day05 import java util Scanner public class DoWhileText01 public static void main String args Scanner i
  • git revert讲解

    git的工作流 工作区 即自己当前分支所修改的代码 git add xx 之前的 不包括 git add xx 和 git commit xxx 之后的 暂存区 已经 git add xxx 进去 且未 git commit xxx 的 本
  • Pandas基本数据对象及操作

    1 Series 创建Series import pandas as pd countries 中国 美国 澳大利亚 countries s pd Series countries print type countries s print
  • HTTP Connection 头(header)说明:keep-alive和closer的区别

    HTTP Connection 头 header 说明 keep alive和closer的区别 前言 在http请求时 我们一般会在request header 或 response header 中看到 Connection Keep
  • IntelliJ IDEA创建Spring Initializr项目!

    目录 1 创建项目 2 点击选择Spring Initializr创建项目 编辑 3 选择项目所需的依赖 4 进入项目后等待加载完成 注意 5 整个项目架构图 编辑 6 项目启动 1 创建项目 一共有两种打开方式 一 在项目里创建Modul
  • 我朋友月薪5w,跟他聊过之后,才知道差距在哪里!

    当我开始工作的时候 年薪50万对于我来说是一个遥不可及的幻想 我认为作为一名普通的软件测试工程师 月薪2w已经是天花板了 然而随着时间的推移和经验的积累 看到越来越多的同行拿到高薪时 我才意识到束缚我薪水的不是行业的天花板 而是我自身技术能
  • 跑pytorch报错: The NVIDIA driver on your system is too old

    今天运行pytorch代码发现报错 The NVIDIA driver on your system is too old found version 8000 Please update your GPU driver by downlo
  • 计算机网络基础应用课程标准,王建波《计算机网络基础》课程标准.doc

    文档介绍 设计者 王建波指导老师 蒋本立廖兴张光清设计时间 2013年7月适用专业 计算机网络专业 计算机应用专业 计算机网络基础 课程标准设计者 王建波指导老师 蒋本立廖兴张光清设计时间 2013年7月适用专业 计算机网络专业 计算机应用
  • 在Raspberry Pi上使用PySimpleGUI创建图表

    PySimpleGUI python库在本地GUI和Web界面具有相同代码的能力中脱颖而出 PySimpleGUI并非以图表包为重点 而是具有画布和图形元素 可让您创建实时条形图和实时趋势图 图形元素入门 图形元素可以具有不同的坐标方向 例
  • 用Dockerfile制作一个python环境案例,值得收藏

    Dockerfile文件 无后缀 FROM python 3 7 设置 python 环境变量 ENV PYTHONUNBUFFERED 1 创建 code 文件夹并将其设置为工作目录 RUN mkdir code WORKDIR code
  • 史上最完美的Android沉浸式状态导航栏攻略

    前言 最近我在小破站开发一款新App 叫高能链 我是一个完美主义者 所以不管对架构还是UI 我都是比较抠细节的 在状态栏和导航栏沉浸式这一块 我还是踩了挺多坑 费了挺多精力的 这次我将我踩坑 适配各机型总结出来的史上最完美的Android沉
  • 傻瓜电梯项目实现

    目录 文档介绍 package lift entity Elevator java Entity java Floor java package lift Pretreatment Pretreatment java package lif
  • Elasticsearch——document相关原理

    1 document数据路由原理 1 1 document路由到shard上是什么意思 一个index的数据会被分为多片 每片都在一个shard中 所以说 一个document 只能存在于一个shard中 当客户端创建document的时候
  • [计算机毕业设计]大数据疫情分析与可视化系统

    前言 大四是整个大学期间最忙碌的时光 一边要忙着准备考研 考公 考教资或者实习为毕业后面临的就业升学做准备 一边要为毕业设计耗费大量精力 近几年各个学校要求的毕设项目越来越难 有不少课题是研究生级别难度的 对本科同学来说是充满挑战 为帮助大
  • mysql报错 -- (errno: 13 - Permission denied)

    重启服务器后 mysql没有自启动 手动启动的时候报错 后面经一番折腾后强行用root身份启动后又发现原有的数据库表都不见了 mysql 报错 ERROR 1018 HY000 Can t read dir of db translator
  • 模型选择+过拟合+欠拟合

    模型选择 当我们训练模型时 我们只能访问数据中的小部分样本 最大的公开图像数据集包含大约一百万张图像 而在大部分时候 我们只能从数千或数万个数据样本中学习 将模型在训练数据上拟合的比在潜在分布中更接近的现象称为过拟合 overfitting
  • 从代码角度理解DETR

    一个cnn的backbone 提图像的feature 比如 HWC 同时对这个feature做position embedding 然后二者相加 在Transformer里面就是二者相加 输入encoder 输入decoder 这里有obj