【MMDet Note】MMDetection中Neck之FPN代码理解与解读

2023-11-15


前言

mmdetection/mmdet/models/necks/fpn.py中FPN类的个人理解与解读。


一、总概

本文以mmdetection/configs/base/models/retinanet_r50_fpn.py中的RetinaNet配置参数为例进行分析。
以下是RetinaNet模型的Neck参数配置:

neck=dict(
	type='FPN',
    # in_channal对应ResNet输出的4个尺度特征图channel数
	in_channels=[256, 512, 1024, 2048],
    # FPN 输出的每个尺度输出特征图通道
	out_channels=256,
    # in_channels对应的特征图从index=1开始用,即FPN用了后三个特征图
	start_level=1,
    # 额外输出层的特征图来源
	add_extra_convs='on_input',
    # FPN 输出特征图个数为5, stride = 8,16,32,64,128
	num_outs=5),

RetinaNet整体模型的大概构造如下图所示:
在这里插入图片描述

二、代码解读

1.FPN类

在这里插入图片描述
代码的标注#都是以RetinaNet的config为例的哦~~代码解读与图片中的内容是互相对应的!!!

@NECKS.register_module()
class FPN(BaseModule):
    def __init__(self,
                 in_channels,             # RetinaNet为例 [256, 512, 1024, 2048]
                 out_channels,            # 256
                 num_outs,                # 5
                 start_level=0,           # 1
                 end_level=-1,
                 add_extra_convs=False,   # 'on_input'
                 relu_before_extra_convs=False,
                 no_norm_on_lateral=False,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 upsample_cfg=dict(mode='nearest'),
                 init_cfg=dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super(FPN, self).__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels                              # self.in_channels = [256, 512, 1024, 2048]
        self.out_channels = out_channels                            # self.out_channels = 256    对应图中M3-M5的channel数为256
        self.num_ins = len(in_channels)                             # self.num_ins = 4
        self.num_outs = num_outs                                    # self.num_outs = 5     对应图中P3-P7
        # 下面4个参数对于结构理解关系不大
        self.relu_before_extra_convs = relu_before_extra_convs
        self.no_norm_on_lateral = no_norm_on_lateral
        self.fp16_enabled = False
        self.upsample_cfg = upsample_cfg.copy() # 上采样参数

        if end_level == -1 or end_level == self.num_ins - 1:
            self.backbone_end_level = self.num_ins                  # self.backbone_end_level = 4
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level is not the last level, no extra level is allowed
            self.backbone_end_level = end_level + 1
            assert end_level < self.num_ins
            assert num_outs == end_level - start_level + 1
        self.start_level = start_level                              # self.start_level = 1
        self.end_level = end_level                                  # self.end_level = -1
        self.add_extra_convs = add_extra_convs                      # self.add_extra_convs = 'on_input'
        assert isinstance(add_extra_convs, (str, bool))
        if isinstance(add_extra_convs, str):
            # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
            assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
        elif add_extra_convs:  # True
            self.add_extra_convs = 'on_input'

        
        self.lateral_convs = nn.ModuleList()        # 对应图中橙色虚线框
        self.fpn_convs = nn.ModuleList()            # 对应图中绿色虚线框

        for i in range(self.start_level, self.backbone_end_level):    # start_level = 1, backbone_end_level = 4,整体数量为3
            # 构造conv 1x1,对应图中3个橙色矩阵
            l_conv = ConvModule(
                in_channels[i],
                out_channels,
                1,      # kernel_size = 1
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
                act_cfg=act_cfg,
                inplace=False)
            # 构造conv 3x3,对应图中3个绿色矩阵
            fpn_conv = ConvModule(
                out_channels,
                out_channels,
                3,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)

            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        # 添加额外的conv level (e.g., RetinaNet)
        extra_levels = num_outs - self.backbone_end_level + self.start_level    # extra_levels = 5 - 4 + 1 = 2  
        # 其实不论怎么样这个extra_levels都会>=1(当前理解的也就是,在默认情况下图中的Output中的绿色矩形始终存在)
        if self.add_extra_convs and extra_levels >= 1:
            for i in range(extra_levels):    # 2
                if i == 0 and self.add_extra_convs == 'on_input':                # 当i == 0时,满足条件
                    in_channels = self.in_channels[self.backbone_end_level - 1]  # 当i == 0时,in_channels = in_channels[3] 也即2048,此时构造的对应图中紫色的矩阵
                else:                                                            # 当i == 0时,in_channels = 256
                    in_channels = out_channels
                # 构造conv 3x3, stride=2
                extra_fpn_conv = ConvModule(
                    in_channels,
                    out_channels,
                    3,
                    stride=2,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    inplace=False)
                self.fpn_convs.append(extra_fpn_conv)
        # 因此RetinaNet最终fpn_convs中有5块Conv块,即对应图中绿色虚线框关联的内容有5块

2.def forward

这里重新贴一下上面的图,代码解读与图片中的内容是互相对应的!!!
在这里插入图片描述

    @auto_fp16()
    def forward(self, inputs):
        """Forward function."""
        assert len(inputs) == len(self.in_channels)

        # laterals 用来记录每一次计算后的输出值,可以理解成是一个临时变量temp
        laterals = [
            lateral_conv(inputs[i + self.start_level])              # self.start_level = 1,inputs[i + 1]为C3-C5的输入
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]
        # 此时,laterals 已经记录了C3-C5经过conv 1x1之后得到的M3-M5值(还未upsample)
        
        # build top-down path
        used_backbone_levels = len(laterals)                # 3
        for i in range(used_backbone_levels - 1, 0, -1):    # i in [2,1]
            # In some cases, fixing `scale factor` (e.g. 2) is preferred, but
            #  it cannot co-exist with `size` in `F.interpolate`.
            if 'scale_factor' in self.upsample_cfg:
                # fix runtime error of "+=" inplace operation in PyTorch 1.10
                laterals[i - 1] = laterals[i - 1] + F.interpolate(
                    laterals[i], **self.upsample_cfg)
            else:
                # 这里也就是upsample与相加的操作,可以理解成经过“upsample”与“+”的操作后,才得到真正的M3-M5的值
                prev_shape = laterals[i - 1].shape[2:]
                laterals[i - 1] = laterals[i - 1] + F.interpolate(
                    laterals[i], size=prev_shape, **self.upsample_cfg)
        # 此时,laterals 记录了经过upsample之后得到的新M3-M5值


        # 建立 outputs
        # part 1: from original levels 此处out对应P3-P5
        outs = [
            self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)   # used_backbone_levels = 3
        ]
        # part 2: add extra levels
        if self.num_outs > len(outs):       # self.num_outs = 5
            # use max pool to get more levels on top of outputs
            # (e.g., Faster R-CNN, Mask R-CNN)
            if not self.add_extra_convs:     # self.add_extra_convs = 'on_input'
                for i in range(self.num_outs - used_backbone_levels):
                    outs.append(F.max_pool2d(outs[-1], 1, stride=2))
            # add conv layers on top of original feature maps (RetinaNet)
            else:
                if self.add_extra_convs == 'on_input':             # 满足条件
                    extra_source = inputs[self.backbone_end_level - 1]  # self.backbone_end_level - 1 = 3 , extra_source 对应图中的C5
                elif self.add_extra_convs == 'on_lateral':
                    extra_source = laterals[-1]
                elif self.add_extra_convs == 'on_output':
                    extra_source = outs[-1]
                else:
                    raise NotImplementedError
                # 此处outs增加P6
                outs.append(self.fpn_convs[used_backbone_levels](extra_source))   # self.fpn_convs[used_backbone_levels]对应图中紫色的矩阵
                for i in range(used_backbone_levels + 1, self.num_outs): # i in [4]
                    if self.relu_before_extra_convs:
                        outs.append(self.fpn_convs[i](F.relu(outs[-1])))
                    else:
                        # 此处out增加P7
                        outs.append(self.fpn_convs[i](outs[-1]))  # self.fpn_convs[i]对应con3x3,stride=2     outs[-1]对应P6     这里也对应了之前提到的“在默认情况下图中的Output中的绿色矩形始终存在”
        return tuple(outs)


总结

本文仅代表个人理解,若有不足,欢迎批评指正。

参考:
【夜深人静读MM】MMdetection框架之Neck中的FPN解读
轻松掌握 MMDetection 中常用算法(一):RetinaNet 及配置详解

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【MMDet Note】MMDetection中Neck之FPN代码理解与解读 的相关文章

随机推荐

  • 了解预训练以及在自编码器中的应用

    预训练是一种机器学习技术 在这种技术中 模型被训练以在标注数据少或不存在的情况下自动从未标记的数据中学习 预训练可以为模型提供先验知识 使其能够在特定任务上更好地泛化 预训练过程通常分为两个阶段 无监督预训练和有监督微调 无监督预训练 模型
  • unity屏幕后处理Bloom优化(光晕)

    前言 前几天看米哈游的技术总监说 崩坏3 的bloom效果的实现是 1 高亮像素过滤 2 向下采样 降采样 3 向上采样 4 将模糊后的图像和原图像混合 经过上面的步骤 能高效的实现bloom效果 常规的bloom是使用 提取高亮 卷积滤波
  • [专利与论文-20]:江苏省南京市2022年电子信息申报操作指南

    1 学时认定 每年公需课不能低于30学时 2 流程
  • elastic search中易并行聚合算法,三角选择原则,近似聚合算法浅析

    1 有些聚合分析的算法 是很容易就可以并行的 比如说max 有些聚合分析的算法 是不好并行的 比如说 count distinct 并不是说 在每个node上 直接就出一些distinct value 就可以的 因为数据可能会很多 es会采
  • DMX512协议是什么 DMX512数字灯光控制系统介绍

    基于DMX512控制协议进行调光控制的灯光系统叫做数字灯光系统 目前 包括电脑灯在内的各种舞台效果灯 调光控制器 控制台 换色器 电动吊杆等各种舞台灯光设备 以其对DMX512协议的全面支持 已全面实现调光控制的数字化 并在此基础上 逐渐趋
  • 74HC595 使用记录 国产UTC品牌

    芯片型号 U74HC595A 数据手册时序图 实际测试时序图 通道1 595的14脚 通道2 595 的11脚 通道3 595 的9脚 结论 U74HC595A 国产 UTC品牌 数据手册与实测数据不一致
  • CentOS 7.9 64位 SCC版安装FastDfs和配置Nginx

    最近练习的项目中需要用到FastDfs 和Nginx 这里记录一下安装和配置过程 个人使用部署过程遇到了很多的坑 准备把过程记下来不然忘了 首先 购买 试用阿里云 CentOS 7 9 64位Scc版系统 进入远程桌面 由于项目较老 所以我
  • 尚硅谷电影推荐系统搭建遇到的问题及知识

    尚硅谷电影推荐系统搭建遇到的问题及知识 Hadoop ES问题 Zookeeper Flume ng Kafka Azkaban 其他 腾讯云Superset问题 需更新数据库用户 登录master节点 cd usr local servi
  • java去掉字符串的逗号_java – 从字符串数组中删除逗号

    我想执行像这样的查询 从 xyz DB 中选择ID test 其中用户在 a b 所以相应的代码就像 String s for String user selUsers s user s 从test中选择ID 其中userId在s中 以下代
  • idea中 关于thymeleaf 变量 在html中 报红 以及控制器 返回页面无法追踪的问题

    html页面thymeleaf 的 变量 报红 无法追踪 controller 无法直接追踪 页面 默认配置前缀 templates 后缀 html 可以正常运行 页面跳转以及变量的传递 就是看着有点不舒服 咋办呢 我无意之间发现的 加入s
  • JVM学习笔记

    目录 垃圾回收器 垃圾回收器分类 按线程数分 按工作模式分 按碎片处理方式分 按工作的内存区间分 GC分类与性能指标 性能指标 吞吐量 性能指标 暂停时间 吞吐量vs暂停时间 垃圾回收器 垃圾回收器发展史 7种经典的垃圾收集器 垃圾回收器的
  • [人工智能-综述-3]:人工智能与硅基生命,人类终将成为造物主

    作者主页 文火冰糖的硅基工坊 https blog csdn net HiWangWenBing 本文网址 https blog csdn net HiWangWenBing article details 119061112 目录 引言
  • 145 - Table ' is marked as crashed and should be repai

    145 Table schoolhelp xyb user is marked as crashed and should be repai 145 表 schoolhelp xyb user 被标记为崩溃 应重新修 修复方式 repair
  • Html CSS学习(六)background-position背景图像的定位

    2019独角兽企业重金招聘Python工程师标准 gt gt gt Html CSS学习 六 background position背景图像的定位 在网页中 会有很多的背景图像与一些小的图标等内容 在初学的时候 为了达到页面的效果 都是将原
  • Spring Boot中如何编写优雅的单元测试

    单元测试是指对软件中的最小可测试单元进行检查和验证 在Java中 单元测试的最小单元是类 通过编写针对类或方法的小段代码 来检验被测代码是否符合预期结果或行为 执行单元测试可以帮助开发者验证代码是否正确实现了功能需求 以及是否能够适应应用环
  • Log4j2之JNDI注入(CVE-2021-44228)

    前言 首先要了解什么是Log4j2 Log4j2是一个Java日志组件 主要用于对日志的记录 这次漏洞出现在Log4j2的Lookup功能 使用Lookup可以在日志中添加动态的值 这些变量可以是外部环境变量 也可以是MDC中的变量 还可以
  • 海量数据库(详解缓存处理方法)

    缓存处理大数据 缓存就是将从数据库中获取的结果暂时保存起来在下次使用的时候无需重新到数据库中获取 从而降低数据库的压力 缓存的使用方式可以分为通过程序直接将数据库数据保存到内存中和使用缓存框架两种方式 它主要用于数据变化不是很频繁的情况 而
  • OR36 链表的回文结构

    OR36 链表的回文结构 较难 通过率 29 47 时间限制 3秒 空间限制 32M 知识点 链表栈 描述 对于一个链表 请设计一个时间复杂度为O n 额外空间复杂度为O 1 的算法 判断其是否为回文结构 给定一个链表的头指针A 请返回一个
  • python中抽象类和抽象方法_在Python中定义和使用 抽象类及抽象方法 抽象属性

    原文链接 http www jb51 net article 87710 htm 本文根据自己的理解和思考 对原文略有改动 Python中我们可以使用abc模块来构建抽象类 在讲抽象类之前 先说下抽象方法的实现 抽象方法是基类中定义的方法
  • 【MMDet Note】MMDetection中Neck之FPN代码理解与解读

    文章目录 前言 一 总概 二 代码解读 1 FPN类 2 def forward 总结 前言 mmdetection mmdet models necks fpn py中FPN类的个人理解与解读 一 总概 本文以mmdetection co