Detr源码解读(mmdetection)

2023-11-02

Detr源码解读(mmdetection)

1、原理简要介绍

在这里插入图片描述
整体流程: 在给定一张输入图像后,1)特征向量提取: 首先经过ResNet提取图像的最后一层特征图F。注意此处仅仅用了一层特征图,是因为后续计算复杂度原因,另外,由于仅用最后一层特征图,故对小目标检测不友好,这也是后续deformable detr改进的原因。 2)添加位置编码信息: 经F拉平成一维张量并添加上位置编码信息得到I。3)Transformer中encoder部分4)Transformer中decoder部分,学习位置嵌入object queries。5)FFN部分:6)后续匈牙利匹配+损失计算。

2、mmdetection中源码介绍

2.1. 整体逻辑

Detr的内部逻辑如下:在mmdet/models/detector/single_stage.py。即首先提取图像特征向量,之后经过DetrHead来计算最终的损失。img[b,3,224,224] x[b,2048,7,7]

def forward_train(self,
                  img,
                  img_metas,
                  gt_bboxes,
                  gt_labels,
                  gt_bboxes_ignore=None):
    super(SingleStageDetector, self).forward_train(img, img_metas)
    # img[b,3,224,224] x[b,2048,7,7]
    x = self.extract_feat(img) # 提取图像特征向量  
    # 经过DetrHead得到loss                   
    losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
                                          gt_labels, gt_bboxes_ignore)
    return losses

forward_train
跟其他的检测头差不多,先是调用自己,也就是自身的 forward 函数,得到输出的 class label 和 reg coordinate,再调用自身的 loss 函数,不过这里是重载了一下,将 img_meta 传输进了 forward 函数的参数。执行完outs = self(x, img_metas)跳转到forward的num_levels = len(feats)

def forward_train(self,
                      x,
                      img_metas,
                      gt_bboxes,
                      gt_labels=None,
                      gt_bboxes_ignore=None,
                      proposal_cfg=None,
                      **kwargs):
        """Forward function for training mode.

        Args:
            x (list[Tensor]): Features from backbone.
            img_metas (list[dict]): Meta information of each image每个图像的元信息, e.g.,
                image size, scaling factor, etc.
            gt_bboxes (Tensor): Ground truth bboxes of the image,图像的地面真相框
                shape (num_gts, 4).
            gt_labels (Tensor): Ground truth labels of each box,
                shape (num_gts,).
            gt_bboxes_ignore (Tensor): Ground truth bboxes to be ignored,要忽略的基本事实框,
                shape (num_ignored_gts, 4).
            proposal_cfg (mmcv.Config): Test / postprocessing configuration,测试/后处理配置
                if None, test_cfg would be used.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.损失成分词典。
        """
        assert proposal_cfg is None, '"proposal_cfg" must be None'
        outs = self(x, img_metas) #x[b,2048,7,7]
        if gt_labels is None:
            loss_inputs = outs + (gt_bboxes, img_metas)
        else:
            loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
        losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        return losses

执行完outs = self(x, img_metas)跳转到forward的num_levels = len(feats)。feats[b,2048,7,7]

    def forward(self, feats, img_metas):
        #这里默认为1,因为DETR默认用最后一层特征图
        num_levels = len(feats)
        img_metas_list = [img_metas for _ in range(num_levels)]
        return multi_apply(self.forward_single, feats, img_metas_list)

执行完return multi_apply(self.forward_single, feats, img_metas_list)跳转到forward_single函数

2.2. 图像特征向量提取

mmdet中提取图像特征向量的config配置文件如下,可以发现用ResNet50并只提取了最后一层特征层,即out_indices=(3,)。骨干网络会输出特征图的1/32,输入为【2,3,224,224】。通过backbone后得到图像大小为【2, 2048, 7, 7】和mask大小为【2,7,7】

backbone=dict(
    type='ResNet',
    depth=50,
    num_stages=4,
    out_indices=(3, ),     # detr仅要resnet50的最后一层特征图,并不需要FPN
    frozen_stages=1,
    norm_cfg=dict(type='BN', requires_grad=False),
    norm_eval=True,
    style='pytorch',
    init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'))

2.3. 给图像特征向量添加位置编码信息(forward_single函数,里面是 head 前向的逻辑)。

本部分代码来自mmdet/models/dense_heads/detr_head.py的 forward_single函数中。
  mmdet中生成位置编码信息借助的是mask矩阵(所谓的mask就是为了统一批次大小而对图像进行了pad,被填充的部分在后续计算多头注意力时应该舍弃)故需要一个mask矩阵遮挡住,具体形状为[batch, h,w]这里先贴下生成mask的过程:

batch_size = x.size(0)   
input_img_h, input_img_w = img_metas[0]['batch_input_shape']# 一个批次图像大小
# 先将 mask 设置为全 1
masks = x.new_ones((batch_size, input_img_h, input_img_w))  # [b,224,224]
# 对每一张图来说,在原来图片有像素的地方把 mask 置 0
# 因此 mask 中 padding 的地方才是 1
for img_id in range(batch_size):
    img_h, img_w, _ = img_metas[img_id]['img_shape']    # 创建了一个mask,非0代表无效区域, 0 代表有效区域
    masks[img_id, :img_h, :img_w] = 0                   # 将pad部分置为1,非pad部分置为0.

输入图像的经过resnet50下采样后hw已经变了,所以还需进一步将mask下采样成和图像特征向量一样的shape。代码如下:

# 将每一层的特征图先投影到指定的特征维度,2048通道太多了转成256通道
x = self.input_proj(x) #Conv2d(self.in_channels, self.embed_dims, kernel_size=1)#[b,256,7,7]
# interpolate masks to have the same spatial shape with x
masks = F.interpolate(                                                         #masks[b,7,7]
    masks.unsqueeze(1), size=x.shape[-2:]).to(torch.bool).squeeze(1) # masks和x的shape一样:[b,2,2]

后续便可以生成位置编码部分(mmdet/models/utils/position_encoding.py),代码里采用了sine位置编码,该函数给masks的每个像素位置生成了一个256维的唯一的位置向量。shape:[B, 256, 7, 7]

# position encoding
pos_embed = self.positional_encoding(masks)

2.4 送入Transformer

4.1. 整体逻辑

在得到图像特征向量x=[b,256,7,7]、masks[b,7,7]矩阵以及位置编码pos_embed[b,256,7,7]后,便可送入Transformer。进入transformer的之前四个变量维度分别为, x->[2, 256, 7, 7],mask->[2, 7, 7],query_embed->[100, 256],pos_embed->[2, 256, 7, 7]

# outs_dec: [nb_nb_decdec, bs, num_query, embed_dim]
outs_dec, _ = self.transformer(x, masks, self.query_embedding.weight,pos_embed)

在进入transformer之前,定义了一个query_embed(就是后边的object query),其第一个维度为num_queries(原文解释为一张图片里的最大检测数量),第二个维度为hidden_dim,就是256。

self.query_embedding = nn.Embedding(self.num_query, self.embed_dims)

关键是理清encoder和decoder的QKV分别指啥, 本部分代码来自mmdet\models\utils\transformer.py的 Transformer函数中。看代码:

       bs, c, h, w = x.shape
       # use `view` instead of `flatten` for dynamically exporting to ONNX
       x = x.view(bs, c, -1).permute(2, 0, 1)  # [bs, c, h, w] -> [h*w, bs, c]  [49,2,256]
       pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1)                 # [49,2,256]
       query_embed = query_embed.unsqueeze(1).repeat(                         #[100,b,256]
           1, bs, 1)  # [num_query, dim] -> [num_query, bs, dim]
       mask = mask.view(bs, -1)  # [bs, h, w] -> [bs, h*w]   [2,49]
       """
       经过变换后的四个变量维度分别为, img->[49, 2, 256],mask->[2, 49],
       query_embed->[100, 2, 256],pos_embed->[49, 2, 256]
       """
       
       memory = self.encoder(
           query=x,                            # [49,b,256]
           key=None,
           value=None,
           query_pos=pos_embed,                 # [49,b,256]
           query_key_padding_mask=mask)  # [b,49]
       target = torch.zeros_like(query_embed) # decoder初始化全0
       # out_dec: [num_layers, num_query, bs, dim]
       out_dec = self.decoder(
           query=target,              # 全0的target, 后续在MultiHeadAttn中执行了
           key=memory,              # query = query + query_pos又加回去了。
           value=memory,
           key_pos=pos_embed,
           query_pos=query_embed, # [num_query, bs, dim]
           key_padding_mask=mask)
       # outs_dec: [nb_nb_decdec, bs, num_query, embed_dim] [6,2,100,256]
       out_dec = out_dec.transpose(1, 2)
       memory = memory.permute(1, 2, 0).reshape(bs, c, h, w)
       return out_dec, memory

其中encoder中q就是x,kv分别为None,query_pos代表位置编码,而query_key_padding_mask就是mask。decoder的q是全0的target,后续decoder会迭代更新q,而kv则 是memory,即encoder的输出;key_pos依旧是k的位置信息;query_embed即论文中Object query,可学习位置信息;key_padding_mask依然是mask。

4.2. encoder部分

先看下encoder初始化部分,内部循环调用了6次BaseTransformerLayer,因此只需讲解一层EncoderLayer即可。将img,mask,pos_embed送入transformer encoder中,进行注意力操作。得到[49, 2, 256]的输出

encoder=dict(
    type='DetrTransformerEncoder',
    num_layers=6,                        # 经过6层Layer
    transformerlayers=dict(              # 每层layer内部使用多头注意力
        type='BaseTransformerLayer',
        attn_cfgs=[
            dict(
                type='MultiheadAttention',
                embed_dims=256,           
                num_heads=8,
                dropout=0.1)
        ],
        feedforward_channels=2048,        # FFN中间层的维度   
        ffn_dropout=0.1,
        operation_order=('self_attn', 'norm', 'ffn', 'norm'))), # 定义运算流程

先跳转到mmdet\models\utils\transformer.py的DetrTransformerEncoder函数。再来看下BaseTransformerLayer的forward部分。该部分可以损失detr的核心部分了,因为本质上mmdet内部只是封装了pytorch现有的nn.MultiHeadAtten函数。所以,需要理解nn.MultiHeadAttn中两种mask参数的含义,限于篇幅原因,这里可参考nn.Transformer来理解这两个mask。 不过简单理解就是:attn_mask在detr中没用到,仅用key_padding_mask。attn_mask是为了遮挡未来文本信息用的,而图像可以看到全部的信息,因此不需要用attn_mask。
在这里插入图片描述

def forward(self,
            query,
            key=None,
            value=None,
            query_pos=None,
            key_pos=None,
            attn_masks=None,
            query_key_padding_mask=None,
            key_padding_mask=None,
            **kwargs):
    #Forward function for `TransformerDecoderLayer`.
    norm_index = 0
    attn_index = 0
    ffn_index = 0
    identity = query
    if attn_masks is None:
        attn_masks = [None for _ in range(self.num_attn)]
    elif isinstance(attn_masks, torch.Tensor):
        attn_masks = [
            copy.deepcopy(attn_masks) for _ in range(self.num_attn)
        ]
        warnings.warn(f'Use same attn_mask in all attentions in '
                      f'{self.__class__.__name__} ')
    else:
        assert len(attn_masks) == self.num_attn, f'The length of ' \
                    f'attn_masks {len(attn_masks)} must be equal ' \
                    f'to the number of attention in ' \
                    f'operation_order {self.num_attn}'

    for layer in self.operation_order:                  # 遍历config文件的顺序
        if layer == 'self_attn':
            temp_key = temp_value = query 
            query = self.attentions[attn_index](        # 内部调用nn.MultiHeadAttn
                query,
                temp_key,
                temp_value,
                identity if self.pre_norm else None,
                query_pos=query_pos,                    # 若有位置编码信息则和query相加 
                key_pos=query_pos,                       # 若有位置编码信息则和key相加 
                attn_mask=attn_masks[attn_index],
                key_padding_mask=query_key_padding_mask,
                **kwargs)
            attn_index += 1
            identity = query

        elif layer == 'norm':
            query = self.norms[norm_index](query)      # 层归一化
            norm_index += 1

        elif layer == 'cross_attn':                    # decoder用到
            query = self.attentions[attn_index](     
                query,
                key,
                value,
                identity if self.pre_norm else None,
                query_pos=query_pos,                   # 若有位置编码信息则和query相加 
                key_pos=key_pos,                        # 若有位置编码信息则和key相加 
                attn_mask=attn_masks[attn_index],
                key_padding_mask=key_padding_mask,
                **kwargs)
            attn_index += 1
            identity = query

        elif layer == 'ffn':                         # 残差连接加全连接层
            query = self.ffns[ffn_index](
                query, identity if self.pre_norm else None)
            ffn_index += 1
    return query

decoder部分和encoder流程类似,只是多了交叉注意力。decoder部分将[49,2,256]的输出和query_embed[100,2,256]输入到transformer decoder中,得到[6, 2, 100, 256]的输出。这里是合并了6个不同层级解码层的输出,其实只需要最后一层即可。
decoder这里其实是将query_embed和feature做了注意力机制,q为query_embed[100, 2, 256],k为memory也就是feature[49, 2, 256],v也是memory[49, 2, 256]。

**

总结

**
decoder的输出经过Prediction feed-forward networks (FFNs)生成最终的预测。即[6,2,100,256]经过线性层生成[6,2,100,92]的类别预测,经过线性层生成[6, 2, 100, 4]的框坐标预测。
由于后续在detr上改进的论文对匈牙利算法以及loss计算改动不大,因此这部分代码就不讲解了。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Detr源码解读(mmdetection) 的相关文章

  • Swift3.0 -- get & set 方法

    1 Swift中的get set 方法 定义一个Person类 class Person NSObject getter setter 仅供演示 日常开发不用 var name String var name String get 返回成员
  • 针对urdf文件中出现中文注释导致报错的问题

    ubuntu18 04 melodic版本中一般使用python2 7环境 该环境下urdf文件中不能有中文注释出现 否则报错 UnicodeEncodeError ascii codec can t encode characters i
  • 【华为OD机试2023】二进制差异数 100% C++ Java Python

    华为OD机试2023 二进制差异数 100 C Java Python 前言 如果您在准备华为的面试 期间有想了解的可以私信我 我会尽可能帮您解答 也可以给您一些建议 本文解法非最优解 即非性能最优 不能保证通过率 Tips1 机试为ACM
  • Servlet的4种会话管理技术

    Servlet的会话管理技术主要有4种 URL重写 Cookie 隐藏表单域以及HttpSession 下面的例子通过一个动态验证码的实例详细说明了这4种技术的具体应用方式 为了简单起见 动态验证码没有采用图片生成技术 其原理和效果与本例中
  • 国际货运代理行业分析

    国际货运代理面临的困局 是一个时代的结束 https zhuanlan zhihu com p 68044703 国际货运代理行业的未来 https zhuanlan zhihu com p 86243754 国际货运代理的存在有何意义 h
  • Spring boot 2.7以上版本集成swagger

    pom xml 引入依赖
  • 使用轻量级 CDC debezium-server-databend 构建实时数据同步

    作者 韩山杰 Databend Cloud 研发工程师 hantmac Jeremy GitHub Debezium Server Databend 是一个基于 Debezium Engine 自研的轻量级 CDC 项目 用于实时捕获数据库
  • .NET 正则表达式

    using System Text RegularExpressions 使用正则表达式替换主题中的特殊符号 用以生成文件夹或者文件名 string title Regex Replace Subject lt gt string patt

随机推荐

  • MacOS set MarkText and PicGo

    MacOS set MarkText and PicGo 192 xxx picgo h Usage picgo options command Options v version output the version number d d
  • win7计算机脱机,win7浏览器解除脱机状态的简单方法(图文)

    大家都知道 脱机工作就是浏览器不使用网络工作 而是使用曾就浏览网页的时候已经下载下来的文件来工作 如果win7纯净版系统浏览器开启了脱机工作 打开陌生的网页的时候就会看到下面的这个错误提示 目前没有可以的Internet 连接 要查看已保存
  • 台式计算机怎么看有没有开独显,怎么看是集显还是独显?台式电脑应该怎么选择...

    怎么看是集显还是独显 台式电脑应该怎么选择 在装机的时候我们经常会听到集显和独显 很多小白也并不知道集显和独显的区别 甚至都不知道怎么看自己的电脑是独显还是集显 下面电脑配置网就来为大家科普一下关于集显和独显的相关问题 显卡是电脑主机不可缺
  • sklearn进行归一化

    利用神经网络预测数据时 突然就被一个以前从来没想过的问题困扰了 训练集与测试集应该分别进行归一化还是合并为一个大的矩阵统一进行归一化 如果放在一起 测试集会参与到模型的训练当中 感觉不对 如果分开 怎么对测试集进行归一化呢 咨询老师 得到了
  • rk3368 Android9.0 升级固件后,一直在recovery界面

    升级固件后 一直在recovery界面 Platform RK3368 OS Android 9 0 Kernel 4 4 194 文章目录 升级固件后 一直在recovery界面 1 串口日志 2 解决方法 1 串口日志 E Failed
  • 欧拉操作系统和linux区别

    centos和Linux区别 linux与centos的区别与联系 1 centos是基于linux建立的操作系统 2 linux属于内核系统 只有终端命令界面 无图形界面 3 centos同时拥有终端命令界面和图形界面 4 linux和c
  • Warning:(23, 8) java: lombok.javac.apt.LombokProcessor could not be initialized.

    项目背景 开发工具idea springboot项目 在使用 Data时 启动项目后控制台报以下信息 Warning 23 8 java lombok javac apt LombokProcessor could not be initi
  • 闭关之 C++ 并发编程笔记(二):同步、内存模型和原子操作

    目录 第4章 并发操作的同步 4 1 等待事件或等待其他条件 4 1 1 凭借条件变量等待条件成立 4 1 2 利用条件变量构建线程安全的队列 4 2 使用future等待一次性事件发生 4 2 1 从后台任务返回值 4 2 2 关联fut
  • 定时删除某目录下几天前的文件

    系统每天生成日志 为了保证系统正常运行 需要不断清理系统空间 就把这种事情交给批处理文件来吧 这种方法简单实用 你只要通过任务计划和批处理文件就能实现 1 任务计划 进入 控制面板 任务计划 然后根据提示选择要运行的程序 运行时间即可 2
  • 基于STM32F103 实现按键状态机

    文章目录 开发板 开发环境 前言 按键消抖 按键硬件原理图 软件延时实现思路 实验目的 代码 按键状态 按键信息 按键相关定义 按键底层配置及状态获取 总结 开发板 正点原子STM32F103ZET6战舰 开发环境 stm32cubeMX
  • 单片机论文参考:1、基于单片机的电子琴

    摘要 随着社会的发展进步 音乐逐渐成为我们生活中很重要的一部分 有人曾说喜欢音乐的人不会向恶 我们都会抽空欣赏世界名曲 作为对精神的洗礼 本论文设计一个基于单片机的简易电子琴 电子琴是现代电子科技与音乐结合的产物 是一种新型的键盘乐器 它在
  • 遇到python调用selenium库使用chrome时候报错 selenium.common.exceptions.WebDriverException问题

    遇到selenium common exceptions WebDriverException问题 selenium common exceptions WebDriverException Message chromedriver exe
  • CTF中,命令中空格被过滤的解决方法

    转载于 作者 HyyMbb 链接 https blog csdn net a3320315 article details 99773192 来源 CSDN博客 1 linux cat flag txt cat IFS flag txt c
  • 【Matlab代码实现】电动过滤器:LPF和HPF、模拟调制:调幅和调频、WiFi、蓝牙和蜂窝网络的容量分析.....

    欢迎来到本博客 博主优势 博客内容尽量做到思维缜密 逻辑清晰 为了方便读者 座右铭 行百里者 半于九十 本文目录如下 目录 1 概述 2 运行结果 3 参考文献 4 Matlab代码实现 1 概述 为了更好 更深入地了解工程通信原理 需要获
  • 矢量图形绘制Illustrator 2022简体中文

    Illustrator 2022简称AI 是简体中文的版本 在矢量图形创作的行业是首选的软件 作为一个行业标准的AI 用途越来越多 比如产品的包装 书籍的插图插画设计 大街上随处可见的广告牌设计 社交媒体等的插图 越来越多的人使用AI进行创
  • 学习笔记(118):R语言入门基础-最年长和最年轻富豪

    立即学习 https edu csdn net course play 24913 285865 utm source blogtoedu na omit 显示非NA的数据
  • latex常见编译错误和细节

    1 nu sqrt mu等一些常用希腊字母和字符不能放在 text 里面 text sqrt text mu 是错误用法 2 使用左右尖括号 left langle和 right rangle 持续更新
  • 性能测试工具 Locust 分布式进行性能测试

    目录 前言 Options 介绍 前言 在进行性能测试时 我们需要确保应用程序在分布式环境下仍能正常运行 使用 Locust 进行性能测试时 当一台单机不足以模拟所需的用户数量的时候 可以在多台机器上分布式的执行性能测试 首先 需要在主机下
  • 认识kubenetes的核心组件之一kubelet

    kubelet是在每个 Node 节点 包括master节点 上运行的主要 节点代理 默认监听10250端口 kubelet主要完成如下一些任务 接收并执行master发送来的指令 管理pod以及pod中的容器 管理节点和更新节点状态信息
  • Detr源码解读(mmdetection)

    Detr源码解读 mmdetection 1 原理简要介绍 整体流程 在给定一张输入图像后 1 特征向量提取 首先经过ResNet提取图像的最后一层特征图F 注意此处仅仅用了一层特征图 是因为后续计算复杂度原因 另外 由于仅用最后一层特征图