Deformable Detr代码阅读

2023-10-29

前言

本文主要是自己在阅读mmdet中Deformable Detr的源码时的一个记录,如有错误或者问题,欢迎指正

deformable attention的流程

在这里插入图片描述
首先zq即为object query,通过一个线性层,先预测出offset,后将三组offset添加到reference point上来得到采样后的位置,object query通过一个线性层和softmax,获取到attention weight(这就说明了deformable attention根本不需要用K点乘V来算attention weight,因为其attention weight是通过object query学到的),将attention weight与采样点的feature相乘,就得到了聚合后的value,在通过一个linear,就得到了output

提取feature map

Deformable Detr相对于detr的一个改进就是使用了多尺度的特征图,从配置文件中我们也可以看出

 backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(1, 2, 3),   # 使用了resnet的3层feature map
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='ChannelMapper',
        in_channels=[512, 1024, 2048],
        kernel_size=1,
        out_channels=256,         # 将三层feature map的输出通道统一为256
        act_cfg=None,
        norm_cfg=dict(type='GN', num_groups=32),
        num_outs=4),

在代码层面,和DETR一样,首先是进入single_stage的forward_train中来提取feature map

super(SingleStageDetector, self).forward_train(img, img_metas)
x = self.extract_feat(img)
losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
                                              gt_labels, gt_bboxes_ignore)

这里的x是resnet提取出来的全部四层feature map
在这里插入图片描述
然后进入到detr_head的forward_rain()中(因为是deformable detr head是基础了DETRHead),在DETRHead的forward_train()中,通过下面的代码的进入到deformable detr head的forward中

outs = self(x, img_metas)

deformable detr head

deformable detr head的整体逻辑和detr_head几乎相同,不同之处在于使用了多尺度的feature map。

生成mask矩阵

        batch_size = mlvl_feats[0].size(0)
        input_img_h, input_img_w = img_metas[0]['batch_input_shape']
        img_masks = mlvl_feats[0].new_ones(
            (batch_size, input_img_h, input_img_w))
  # 对于batch_size中的每一个图片,生成相应的原图的mask矩阵,将原始图像部分设置为0,1的位置表示pad部分
        for img_id in range(batch_size):
            img_h, img_w, _ = img_metas[img_id]['img_shape']
            img_masks[img_id, :img_h, :img_w] = 0
		
        mlvl_masks = []
        mlvl_positional_encodings = []
        #对原来的每个img_masks进行下采样,使其和相应的feature map大小相匹配
        for feat in mlvl_feats:
            mlvl_masks.append(
   #索引当中的None是增加维度的作用,img_masks扩充了一个维度:[b,h,w]-->[1,b,h,w]
                F.interpolate(img_masks[None],
                              size=feat.shape[-2:]).to(torch.bool).squeeze(0))
   # 生成positionan encoding,因为mlvl_masks每次append都是在最后一个,所以这里的索引每次取-1就好
            mlvl_positional_encodings.append(
                self.positional_encoding(mlvl_masks[-1]))

mlvl_feats如下所示,我这里batch_size为1
mlvl_fea
这里有一个点值得注意,就是为什么在进行F.interpolate之前要先使用img_masks[None]增加一个维度,这是因为F.interpolate函数对于要采样的矩阵的维度有要求,即为批量(batch_size)×通道(channel)×[可选深度]×[可选高度]×宽度(前两个维度具有特殊的含义,不进行采样处理)
参考:F.interpolate——数组采样操作

进入transformer

在deformable detr head的forward中,通过下面的代码进入transformer

query_embeds = None
if not self.as_two_stage:
            query_embeds = self.query_embedding.weight

hs, init_reference, inter_references, \
enc_outputs_class, enc_outputs_coord = self.transformer(
             mlvl_feats,
             mlvl_masks,
             query_embeds,   #[300,512]  [num_query,embed_dims * 2]
             mlvl_positional_encodings,
             reg_branches=self.reg_branches if self.with_box_refine else None,  # noqa:E501
             cls_branches=self.cls_branches if self.as_two_stage else None  # noqa:E501
            )

代码跳转到DeformableDetrTransformer的forward中,首先会进行一些进入transformer的准备工作

		feat_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        # 将各个特征层的feature map,mask等拉直
        for lvl, (feat, mask, pos_embed) in enumerate(
                zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
            bs, c, h, w = feat.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            feat = feat.flatten(2).transpose(1, 2) # [bs,h*w,c]
            mask = mask.flatten(1)  # [bs,h*w]
            pos_embed = pos_embed.flatten(2).transpose(1, 2)  # [bs,h*w,c]
            lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            feat_flatten.append(feat)
            mask_flatten.append(mask)
        feat_flatten = torch.cat(feat_flatten, 1) # [bs,四层的h*w加起来,c]
        mask_flatten = torch.cat(mask_flatten, 1) # [bs,四层的h*w加起来]
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # [bs,四层的h*w加起来,c]
        #转成tensor
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=feat_flatten.device)
        # 记录每一层feature map的起始位置
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
         #得到每张特征图的有效宽高比例  [bs,4(num_levels),2(长和宽)]
        valid_ratios = torch.stack(
            [self.get_valid_ratio(m) for m in mlvl_masks], 1)

获取reference point

通过下面的函数获取reference point,最后得到的reference point是在0-1尺度上的值

    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Get the reference points used in decoder.

        Args:
            spatial_shapes (Tensor): The shape of all
                feature maps, has shape (num_level, 2).
            valid_ratios (Tensor): The radios of valid
                points on the feature map, has shape
                (bs, num_levels, 2)
            device (obj:`device`): The device where
                reference_points should be.

        Returns:
            Tensor: reference points used in decoder, has \
                shape (bs, num_keys, num_levels, 2).
        """
        reference_points_list = []
        for lvl, (H, W) in enumerate(spatial_shapes):
            #  TODO  check this 0.5
            # 获取每个reference point中心横纵坐标,加减0.5是确保每个初始点是在每个pixel的中心
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(
                    0.5, H - 0.5, H, dtype=torch.float32, device=device),
                torch.linspace(
                    0.5, W - 0.5, W, dtype=torch.float32, device=device))
            # 将横纵坐标进行归一化
            ref_y = ref_y.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 1] * H)
            ref_x = ref_x.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 0] * W)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        # 将参考点的位置映射到有效区域
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

encoder

  memory = self.encoder(
    query=feat_flatten, # 输入query,是展平后的多尺度feature map [所有H*W的和, bs, 256]
    key=None,     #在self attention中,k和v是由q算出,因此输入为None
    value=None,
    query_pos=lvl_pos_embed_flatten, #输入query的位置编码, [所有H*W的和, bs, 256]
    query_key_padding_mask=mask_flatten, # padding mask [bs, 所有H*W的和]
    spatial_shapes=spatial_shapes, #每层feature map的h和w  [num_levels, bs]
    reference_points=reference_points, #[bs, 所有H*W的和, num_levels, 2]
    level_start_index=level_start_index,# 每层feature map展平后的第一个元素的位置索引 [num_levels]
    valid_ratios=valid_ratios, # 每层feature map对应的mask中有效的宽高比 [B, num_levels, 2]
    **kwargs)
# memory:encoder的输出,经过自注意力后的多尺度feature map [所有H*W的和, bs, 256]

进入encoder之后会按照在配置文件中的的顺序来

encoder=dict(
      type='DetrTransformerEncoder',
      num_layers=6,
      transformerlayers=dict(
      type='BaseTransformerLayer',
      attn_cfgs=dict(
      type='MultiScaleDeformableAttention', embed_dims=256),
      feedforward_channels=1024,
      ffn_dropout=0.1,
      operation_order=('self_attn', 'norm', 'ffn', 'norm'))),

这里的self-attn变成了MultiScaleDeformableAttention,
MultiScaleDeformableAttention的代码如下:在mmcv\ops\multi_scale_deform_attn.py中

        if value is None:
            value = query

        if identity is None:
            identity = query
        if query_pos is not None:
            query = query + query_pos
        if not self.batch_first:
            # change to (bs, num_query ,embed_dims)
            query = query.permute(1, 0, 2)
            value = value.permute(1, 0, 2)

        bs, num_query, _ = query.shape
        bs, num_value, _ = value.shape
        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
		# value的值是从query中学到的,最开始的value为None,被赋值为query,然后通过一个线性层得到真正的value [bs,所有H*W的和,256]
        value = self.value_proj(value)
        if key_padding_mask is not None:
            value = value.masked_fill(key_padding_mask[..., None], 0.0)
        #[bs,所有H*W的和,256] ---> [bs,所有H*W的和,8,32]
        value = value.view(bs, num_value, self.num_heads, -1)
'''
self.sampling_offsets:
Linear(in_features=256, out_features=256, bias=True)
self.attention_weights:
Linear(in_features=256, out_features=128, bias=True)
'''

        # sampling_offsets : [bs,所有H*W的和, 8, 4, 4, 2]
        sampling_offsets = self.sampling_offsets(query).view(
            bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
        # attention_weights:[1, 10458, 8, 16]
        attention_weights = self.attention_weights(query).view(
            bs, num_query, self.num_heads, self.num_levels * self.num_points)
         # 为啥要softmax?
         # 经过一个线性层映射+softmax得到每个query的注意力权重
        attention_weights = attention_weights.softmax(-1)
		 #[1, 所有H*W的和, 8, 16] ---> [1,所有H*W的和,8,4,4]
        attention_weights = attention_weights.view(bs, num_query,
                                                   self.num_heads,
                                                   self.num_levels,
                                                   self.num_points)
        if reference_points.shape[-1] == 2:
    # 首先是sampling_offsets / offset_normalizer进行归一化 然后再和reference_points相加
            offset_normalizer = torch.stack(
                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                + sampling_offsets \
                / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                + sampling_offsets / self.num_points \
                * reference_points[:, :, None, :, None, 2:] \
                * 0.5
        else:
            raise ValueError(
                f'Last dim of reference_points must be'
                f' 2 or 4, but get {reference_points.shape[-1]} instead.')
                
         # 调用cuda算子进行deformable atten
        if torch.cuda.is_available() and value.is_cuda:
            output = MultiScaleDeformableAttnFunction.apply(
                value, spatial_shapes, level_start_index, sampling_locations,
                attention_weights, self.im2col_step)
        else:
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights)

        output = self.output_proj(output)

        if not self.batch_first:
            # (num_query, bs ,embed_dims)
            output = output.permute(1, 0, 2)
        # 这个identity是上一次的query
        return self.dropout(output) + identity

在做完multi_scale_deformable_attn之后,会进行norm,ffn,norm,这样一个encoder layer就走完了,这个过程将重复6次,最后返回到DeformableDetrTransformer的forward中,返回值memory为encoder的输出,也即经过multi_scale_deformable_attn后的多尺度feature map,其维度为:[所有H*W的和, bs, 256]

decoder

inter_states, inter_references = self.decoder(
            query=query, # [num_query,bs,256]
            key=None,
            value=memory,  # encoder的输出 经过encoder后的feature map
            query_pos=query_pos,
            key_padding_mask=mask_flatten,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            reg_branches=reg_branches,
            **kwargs)
        query_pos, query = torch.split(query_embed, c, dim=1)
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1) #[bs,300,256]
        query = query.unsqueeze(0).expand(bs, -1, -1)#[bs,300,256]
        # 将query_pos经过一次线性变换+sigmoid正好能作为初始参考点坐标
        reference_points = self.reference_points(query_pos).sigmoid()
        init_reference_out = reference_points

        # decoder
        query = query.permute(1, 0, 2) #[300(num_query),bs,256]
        memory = memory.permute(1, 0, 2) #[所有H*W的和,bs,256]
        query_pos = query_pos.permute(1, 0, 2)#[300(num_query),bs,256]
        inter_states, inter_references = self.decoder(
            query=query, #[300(num_query),bs,256]
            key=None,
            value=memory,#经过encoder的feature map
            query_pos=query_pos,
            key_padding_mask=mask_flatten,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            reg_branches=reg_branches, #None
            **kwargs)

进入到self.decoder中后,代码跳转到DeformableDetrTransformerDecoder中的forward函数中,在mmdetection/mmdet/models/utils/transformer.py中

		output = query
        intermediate = [] #存储每层decoder layer的query
        intermediate_reference_points = [] # 用来存储每层decoder layer的reference_points
        for lid, layer in enumerate(self.layers):
            if reference_points.shape[-1] == 4:
                reference_points_input = reference_points[:, :, None] * \
                    torch.cat([valid_ratios, valid_ratios], -1)[:, None]
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = reference_points[:, :, None] * \
                    valid_ratios[:, None]
            output = layer(
                output,   # query
                *args,
                reference_points=reference_points_input,
                **kwargs) 
# kwargs包含了['key', 'value', 'query_pos', 'key_padding_mask', 'spatial_shapes', 'level_start_index']
# key为None ,value为从encoder中得到的memory
            output = output.permute(1, 0, 2)
            # reg_branches默认问None
            if reg_branches is not None:
                tmp = reg_branches[lid](output)
                if reference_points.shape[-1] == 4:
                    new_reference_points = tmp + inverse_sigmoid(
                        reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                else:
                    assert reference_points.shape[-1] == 2
                    new_reference_points = tmp
                    new_reference_points[..., :2] = tmp[
                        ..., :2] + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                reference_points = new_reference_points.detach()

            output = output.permute(1, 0, 2)
            # 将中间的query和reference_point存下来,query有更新,reference_points其实每一层都是一样的
            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:  # true
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)

        return output, reference_points

decoder最后返回两个值,也即所有六层decoder的query和reference_points,每一层的query是不同的,但是每一层的referen_points是相同的
在这里插入图片描述
最后整个transformer返回三个值,inter_states,init_reference_out,inter_references_out
inter_states :[num_dec_layers, bs, num_query, embed_dims] 表示每个decode layer的query
init_reference_out : [bs,num_query,2] 表示最开始的reference_points
inter_references_out:[num_dec_layers, bs, num_query, embed_dims] 表示每一层的reference points
在这里插入图片描述

预测部分

在经过了transformer部分之后,代码回到了deformable detr head中

		hs = hs.permute(0, 2, 1, 3)
        outputs_classes = []
        outputs_coords = []
		
		# 逐个decoder layer去做预测
        for lvl in range(hs.shape[0]):
            if lvl == 0:
                reference = init_reference
            else:
                reference = inter_references[lvl - 1]
            reference = inverse_sigmoid(reference)  # 做反sigmoid
            outputs_class = self.cls_branches[lvl](hs[lvl])
            # 这里预测出的tmp是相对于reference的offset
            tmp = self.reg_branches[lvl](hs[lvl])
            if reference.shape[-1] == 4:
                tmp += reference
            else:
                assert reference.shape[-1] == 2
                tmp[..., :2] += reference   #reference与预测出的offset相加
            outputs_coord = tmp.sigmoid()
            outputs_classes.append(outputs_class)
            outputs_coords.append(outputs_coord)

        outputs_classes = torch.stack(outputs_classes)
        outputs_coords = torch.stack(outputs_coords)
        if self.as_two_stage:
            return outputs_classes, outputs_coords, \
                enc_outputs_class, \
                enc_outputs_coord.sigmoid()
        else:
            return outputs_classes, outputs_coords, \
                None, None

后面就是计算loss了,这部分和DETR应该是一样的,我在DETR的源码阅读中已经写过了,这里就不写了,感兴趣的可以去看我的另一篇博客:DETR源码阅读

一些细节:

encoder时候的只有self_atten,QKV都是feature map
decoder时候,self_atten时候,QKV都是object query([num_query,bs,256])
cross_atten时候,Q是object query V是feature map,K这里是None,因为deformable atten不需要通过Q点乘K来获取attention_weight,其attention_weight是通过object query学出来的

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Deformable Detr代码阅读 的相关文章

  • Python - Pandas - 将特定函数应用于给定级别 - 多索引数据帧

    我有一个多索引 DataFrame 并且我希望仅对分配给我的级别之一的向量应用一些计算 使用下面的代码 import pandas as pd import datetime ticker date US datetime date tod
  • ValueError:无法将 DatetimeIndex 转换为 dtype datetime64[us]

    我正在尝试为 S P 500 ETF 创建一个包含 30 分钟数据的 PostgreSQL 表 spy30new 用于测试新插入的数据 来自具有 15 分钟数据的多个股票的表 全部 15 个 all15 在 dt 时间戳 和 instr 股
  • 使用 python 将 bibtex 文件转换为 html (也许是 pybtex?)

    您好 我想解析 bibtex 出版物文件并对特定字段 例如年份 进行排序并过滤某些内容 然后将其放在网站上 我遇到了 pybtex 它可以读取和解析 bibtex 文件 但它基本上没有记录 我不知道如何对条目进行排序 pybtex 是可行的
  • 回归模型 statsmodel python

    这更多是一个统计问题 因为代码运行良好 但我正在学习 python 中的回归建模 我在下面使用 statsmodel 编写了一些代码来创建一个简单的线性回归模型 import statsmodels api as sm import num
  • 使用 PyQt 和 matplotlib 在可滚动小部件中显示多个绘图

    由于我没有得到答案this https stackoverflow com questions 12179893 creating a scrollable multiplot with pythons pylab我尝试用 PyQt 解决这
  • lxml/python 使用 CDATA 部分读取 xml

    在我的 xml 中我有一个CDATA部分 我想保留 CDATA 部分 然后剥离它 有人可以帮忙解决以下问题吗 默认不起作用 from io import StringIO from lxml import etree xml
  • 使用 isdigit 表示浮点数?

    a raw input How much is 1 share in that company while not a isdigit print You need to write a number n a raw input How m
  • if(interactive()) 是否相当于 Python 中的“if __name__ == ”__main__“: main()”?

    我希望 R 脚本有一个 main 函数 可以在交互模式下执行 但在获取文件时不应执行 main 函数 已经有一个关于这个的问题了 https stackoverflow com questions 2968220 is there an r
  • 如何使用httplib2进行相互证书认证

    我正在使用 httplib2 从我的服务器向另一个 Web 服务发出请求 我们想要使用相互证书身份验证 我了解如何使用证书进行传出连接 h set certificate 但是如何检查应答服务器使用的证书 这张票 http code goo
  • 如何检测斑点并将其裁剪成 png 文件?

    我一直在开发一个网络应用程序 我陷入了一个有问题的问题 我会尝试解释我想要做什么 在这里您看到第一个大图像 其中有绿色形状 我想要做的是将这些形状裁剪成不同的 png 文件 并使它们的背景透明 就像大图像下面的示例裁剪图像一样 第一张图像将
  • python中的unicode错误[关闭]

    很难说出这里问的是什么 这个问题是含糊的 模糊的 不完整的 过于宽泛的或修辞性的 无法以目前的形式得到合理的回答 如需帮助澄清此问题以便重新打开 访问帮助中心 help reopen questions 在下面的代码中我收到错误mailSe
  • 使用 matplotlib 在 python3 中对多个形状进行动画处理

    尝试在 python3 中使用 matplotlib 动画函数同时对多个对象进行动画处理 下面写的代码是我到目前为止的位置 我能够创建多个对象并将它们显示在图中 我通过使用包含矩形补丁函数的 for 循环来完成此操作 从这里开始 我希望通过
  • 在视图之间共享并在 AppConfig 中初始化的变量

    我想要一个在应用程序启动时初始化的变量 并且可以从视图访问该变量 my app my config py class WebConfig AppConfig name verbose name def ready self print lo
  • 数据框更新后如何刷新绘图?

    假设您已经使用以下方法构建了一个图形px line 使用数据框 数据框稍后会添加新数据 用新数据刷新数据的好方法是什么 一个例子可以是px data stocks 从列的子集开始 GOOG AAPL AMZN FB NFLX MSFT 例如
  • Python Sqlite3 获取 Sqlite 连接路径

    给定一个 sqlite3 连接对象 如何检索 sqlite3 文件的文件路径 The Python 连接对象 http github com python cpython blob master Modules sqlite connect
  • Django ImageField 默认值

    模型 py class UserProfile models Model photo models ImageField upload to get upload file name storage OverwriteStorage def
  • 如何使用Python3.4在tornado中进行异步mysql操作?

    我现在使用Python3 4 我想在Tornado中使用异步mysql客户端 我已经发现torndb https github com bdarnell torndb但在阅读其源代码后 我认为它无法进行异步mysql操作 因为它只是封装了M
  • 如何从 Django 中的链接设置预定义的表单值?

    我的项目是这样布局的 1 page has many categories 2 category belongs to page has many items 3 item belongs to category 当我进入一个页面时 我想修
  • 加入语音频道(discord.py)

    当我尝试让我的机器人加入我的语音频道时 出现以下错误 await client join voice channel voice channel 产生错误的行 Traceback most recent call last File usr
  • Python 单元测试:Nose 失败时重试?

    我有一个随机失败的测试 我想让它在发送错误消息之前重试多次 我将 python 与 Nose 一起使用 我写了以下内容 但不幸的是 即使使用 try except 处理 当第一次尝试测试失败时 Nose 也会返回错误 def test so

随机推荐