DeformableConv(可形变卷积)理论和代码分析

2023-11-15

DeformableConv(可形变卷积)理论和代码分析

1.DeformableConv理论分析

普通卷积示例
以上图为例进行讲解,在普通卷积中:

  • input输入尺寸为[batch_size, channel, H, W]
  • output输入尺寸也为[batch_size, channel, H, W]
  • kernel尺寸为[kernel_szie, kernel_szie]
    因此output中任意一点p0,对应到input中的卷积采样区域大小为kernel_szie x kernel_szie,卷积操作用公式可表示为:
    y ( p 0 ) = ∑ p n ∈ R w ( p n ) ⋅ x ( p 0 + p n ) \mathbf{y}\left(\mathbf{p}_{0}\right)=\sum_{\mathbf{p}_{n} \in \mathcal{R}} \mathbf{w}\left(\mathbf{p}_{n}\right) \cdot \mathbf{x}\left(\mathbf{p}_{0}+\mathbf{p}_{n}\right) y(p0)=pnRw(pn)x(p0+pn)
    其中, p n \mathbf{p}_{n} pn代表卷积核中每一个点相对于中心点的偏移量,可用如下公式表示(3 x 3卷积核为例):
    R = { ( − 1 , − 1 ) , ( − 1 , 0 ) , … , ( 0 , 1 ) , ( 1 , 1 ) } \mathcal{R}=\{(-1,-1),(-1,0), \ldots,(0,1),(1,1)\} R={(1,1),(1,0),,(0,1),(1,1)}
    在这里插入图片描述

w ( p n ) \mathbf{w}\left(\mathbf{p}_{n}\right) w(pn)代表卷积核上对应位置的权重。 p 0 \mathbf{p}_{0} p0可以看做output上每一个点, y ( p 0 ) \mathbf{y}\left(\mathbf{p}_{0}\right) y(p0)为output上每一点的具体值。 x ( p 0 + p n ) \mathbf{x}\left(\mathbf{p}_{0}+\mathbf{p}_{n}\right) x(p0+pn)为output上每个点对应到input上的卷积采样区域的具体值。该公式整体意思就是卷积操作。

而DeformableConv的计算步骤就是在普通卷积的基础上,加一个模型自己学习的偏移量offset,公式为:
y ( p 0 ) = ∑ p n ∈ R w ( p n ) ⋅ x ( p 0 + p n + Δ p n ) \mathbf{y}\left(\mathbf{p}_{0}\right)=\sum_{\mathbf{p}_{n} \in \mathcal{R}} \mathbf{w}\left(\mathbf{p}_{n}\right) \cdot \mathbf{x}\left(\mathbf{p}_{0}+\mathbf{p}_{n}+\Delta \mathbf{p}_{n}\right) y(p0)=pnRw(pn)x(p0+pn+Δpn)
公式中用 Δ p n \Delta \mathbf{p}_{n} Δpn表示偏移量。需要注意的是,该偏移量是针对 x \mathbf{x} x的,也就是可变形卷积变的不是卷积核,而是input。
在这里插入图片描述
从上图可以看到,在input上按照普通卷积操作,output上的一点对应到input上的卷积采样区域是一个卷积核大小的正方形,而可变形卷积对应的卷积采样区域为一些蓝框表示的点,这就是可变形卷积与普通卷积的区别。

接下来说一下可变形卷积的具体细节,以N x N的卷积核为例。一个output上的点对应到input上的卷积采样区域大小为N x N,按照可变形卷积的操作,这N x N区域的每一个卷积采样点都要学习一个偏离量offset,而offset是用坐标表示的,所以一个output要学习2N x N个参数。一个output大小为H x W,所以一共要学习2NxN x H x W个参数。即上图的offset field,其维度为B x 2NxN x H x W,其中B代表batch_size,(上图是网络上找的,里面的N指卷积核的面积)。值得注意的两点细节是:

  • input(假设维度为B x C x H x W)一个batch内的所有通道上的特征图(一共C个)共用一个offset field,即一个batch内的每张特征图用到的偏移量是一样的。
  • 可变形卷积不改变input的尺寸,所以output也为H x W。

一个点加上offset之后,大概率不会得到一个整数坐标点,这时就应该进行线性插值操作,具体做法为找出和偏移后的点距离小于1且最近的四个点,将其值乘以权重(权重以距离来衡量)再进行加和操作,即为最终偏移后的点的值。

2.DeformableConv代码分析

依照上述理论,可将DeformableConv的代码实现拆解为如下流程图:
在这里插入图片描述

2.1初始化操作

class DeformConv2d(nn.Module):
    def __init__(self, 
                 inc, 
                 outc, 
                 kernel_size=3, 
                 padding=1, 
                 stride=1, 
                 bias=None, 
                 modulation=False):
        """
        Args:
            modulation (bool, optional): If True, Modulated Defomable 
            Convolution (Deformable ConvNets v2).
        """
        super(DeformConv2d, self).__init__()
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.zero_padding = nn.ZeroPad2d(padding)
        self.conv = nn.Conv2d(inc, #该卷积用于最终的卷积
                              outc, 
                              kernel_size=kernel_size, 
                              stride=kernel_size, 
                              bias=bias)

        self.p_conv = nn.Conv2d(inc, #该卷积用于从input中学习offset
                                2*kernel_size*kernel_size, 
                                kernel_size=3, 
                                padding=1, 
                                stride=stride)
        nn.init.constant_(self.p_conv.weight, 0)
        self.p_conv.register_backward_hook(self._set_lr)

        self.modulation = modulation #该部分是DeformableConv V2版本的,可以暂时不看
        if modulation:
            self.m_conv = nn.Conv2d(inc, 
                                    kernel_size*kernel_size, 
                                    kernel_size=3, 
                                    padding=1, 
                                    stride=stride)
            nn.init.constant_(self.m_conv.weight, 0)
            self.m_conv.register_backward_hook(self._set_lr)

    @staticmethod
    def _set_lr(module, grad_input, grad_output):
        grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
        grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))

2.2执行

    def forward(self, x):
        offset = self.p_conv(x) #此处得到offset
        if self.modulation:
            m = torch.sigmoid(self.m_conv(x))

        dtype = offset.data.type()
        ks = self.kernel_size
        N = offset.size(1) // 2

        if self.padding:
            x = self.zero_padding(x)

        # (b, 2N, h, w)
        p = self._get_p(offset, dtype)

        # (b, h, w, 2N)
        p = p.contiguous().permute(0, 2, 3, 1)
        q_lt = p.detach().floor()
        q_rb = q_lt + 1

        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()
        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()
        q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
        q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)

        # clip p
        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)

        # bilinear kernel (b, h, w, N)
        g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
        g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
        g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
        g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))

        # (b, c, h, w, N)
        x_q_lt = self._get_x_q(x, q_lt, N)
        x_q_rb = self._get_x_q(x, q_rb, N)
        x_q_lb = self._get_x_q(x, q_lb, N)
        x_q_rt = self._get_x_q(x, q_rt, N)

        # (b, c, h, w, N)
        x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
                   g_rb.unsqueeze(dim=1) * x_q_rb + \
                   g_lb.unsqueeze(dim=1) * x_q_lb + \
                   g_rt.unsqueeze(dim=1) * x_q_rt

        # modulation
        if self.modulation:
            m = m.contiguous().permute(0, 2, 3, 1)
            m = m.unsqueeze(dim=1)
            m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)
            x_offset *= m

        x_offset = self._reshape_x_offset(x_offset, ks)
        out = self.conv(x_offset)

        return out
        
    def _get_p_n(self, N, dtype): #求
        p_n_x, p_n_y = torch.meshgrid(
            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))
        # (2N, 1)
        p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
        p_n = p_n.view(1, 2*N, 1, 1).type(dtype)

        return p_n

    def _get_p_0(self, h, w, N, dtype):
        p_0_x, p_0_y = torch.meshgrid(
            torch.arange(1, h*self.stride+1, self.stride),
            torch.arange(1, w*self.stride+1, self.stride))
        p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)

        return p_0

    def _get_p(self, offset, dtype):
        N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)

        # (1, 2N, 1, 1)
        p_n = self._get_p_n(N, dtype)
        # (1, 2N, h, w)
        p_0 = self._get_p_0(h, w, N, dtype)
        p = p_0 + p_n + offset
        return p

    def _get_x_q(self, x, q, N):
        b, h, w, _ = q.size()
        padded_w = x.size(3)
        c = x.size(1)
        # (b, c, h*w)
        x = x.contiguous().view(b, c, -1)

        # (b, h, w, N)
        index = q[..., :N]*padded_w + q[..., N:]  # offset_x*w + offset_y
        # (b, c, h*w*N)
        index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)

        x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)

        return x_offset

    @staticmethod
    def _reshape_x_offset(x_offset, ks):
        b, c, h, w, N = x_offset.size()
        x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)
        x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)

        return x_offset
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

DeformableConv(可形变卷积)理论和代码分析 的相关文章

  • Android DataBinding 学习(二)

    dataBinding 二 1 在布局中使用vm变量进行资源判断 场景 点击按钮 对应的圆和按钮本身的背景颜色发生改变 不需要单独在代码中设置其背景色 可以直接在布局中镶嵌 VM public class TestVM public Obs
  • UE4_Python编写,Pycharm智能提示API

    1 按照教程配置环境 在对应的目录下会有一个unreal py 的文件 把它粘贴到对应的Python的项目目录 2 因为Pycharm 的py 文件 默认支持的智能提示是 the file size 10 5mb exceeds confi

随机推荐

  • Mk配置aar文件遇到的问题记录

    第一步 include CLEAR VARS LOCAL PREBUILT STATIC JAVA LIBRARIES demo libs demo aar 要添加的aar LOCAL AAPT FLAGS auto add overlay
  • LocalDateTime和字符串相互转换------时间转换:

    Test public void timeTest04 throws ParseException String dateTime 2022 03 21T02 29 13 732843 DateTimeFormatter dateTimeF
  • MTK Pump Express 快速充电原理分析

    1 MTK PE 1 1 原理 在讲正文之前 我们先看一个例子 对于一块电池 我们假设它的容量是6000mAh 并且标称电压是3 7V 换算成Wh 瓦时 为单位的值是22 3Wh 6000mAh 3 7V 普通的充电器输出电压电流是5V2A
  • ArcGIS 解决影像裁剪后锯齿问题

    矢量数据裁剪栅格数据的原理 个人理解 当输入矢量数据的范围完全包含或包含一个像元大小的50 及以上 裁剪时就默认把这个像元作为输出像元 反之 不输出 如下图 从而导致影像裁剪后存在锯齿问题 因此锯齿问题归根结底就是影像分辨率问题 导致结果就
  • chatgpt赋能python:Pythontomorrow:未来十年最重要的编程语言

    Python tomorrow 未来十年最重要的编程语言 Python 是一种高级 通用 解释型 面向对象的动态编程语言 自 1991 年诞生以来 Python 已成为了世界上最流行的编程语言之一 然而 Python 仍没有达到顶峰 未来的
  • C++客户端Modbus通信(TCP主站)

    本文简单介绍Qt使用外部modbus通信C 编程流程 modbus中文手册 https blog csdn net qq 23670601 article details 82155378 Qtmodbus较为方便 建议无特殊情况可以使用q
  • mysql中去除重复数据,只保留一条。

    梳理一下关于删除重复记录的逻辑 目录 前期准备 建表插入数据 1 通过group by 和count 1 gt 1找出有重复的数据 2 通过每个分组中的最小id来去重 2 1 添加主键id列 2 2 去重 2 2 1 首先找出每个分组中co
  • 数据结构:栈和队列的实现和图解二者相互实现

    文章目录 写在前面 栈 什么是栈 栈的实现 队列 什么是队列 队列的实现 用队列实现栈 用栈模拟队列 写在前面 栈和队列的实现依托的是顺序表和链表 如果对顺序表和链表不清楚是很难真正理解栈和队列的 下面为顺序表和链表的实现和图解讲解 手撕图
  • MySQL里datetime字段怎么设置默认时间

    Mysql 如何设置字段自动获取当前时间 TimeStamp和DateTime 转 MySQL datetime数据类型设置当前时间为默认值 两个方法 dateTime TimeStamp类型 建表时的设置 参考 mysql中datetim
  • 【ubuntu

    every blog every motto You can do more than you think https blog csdn net weixin 39190382 type blog 0 前言 ubuntu 22 04 安装
  • Shell脚本执行FTP操作

    一 从本地上传单个文件到FTP bin bash PUTFILE test txt ftp i v n ftp ip ftp port lt
  • android小项目之音乐播放器二

    Android应用开发 MP3音乐播放器代码实现 一 需求1 将内存卡中的MP3音乐读取出来并显示到列表当中 1 从数据库中查询所有音乐数据 保存到List集合当中 List当中存放的是Mp3Info对象 2 迭代List集合 把每一个Mp
  • LaTex的Algorithm的\caption里边的编号设置

    只需在文件头部设置 setcounter algorithm 2 就会从3开始编号 效果如下 参考
  • mysql取分组后最新的一条记录

    mysql取分组后最新的一条记录 下面两种方法 一种是先筛选 出最大和最新的时间 在连表查询 一种是先排序 然后在次分组查询 默认第一条 就是最新的一条数据了 select from t assistant article as a sel
  • python知识点总结assert利用蚁剑登录

    1 python变量 变量Python 是强类型的动态脚本语言 强类型 不允许不同类型相加 例如 整形 字符串会报类型错误 动态 不使用显示数据类型声明 且确定一个变量的类型是在第一次给它赋值的时候 脚本语言 一般是解释性语言 运行代码只需
  • Unity3d数字地球加载Arcgis数据(shp)、DEM数据(tif)、点云(las)、倾斜摄影模形(flp、osgb)

    Unity3d数字地球加载Arcgis数据 shp DEM数据 tif 点云 las 倾斜摄影模形 flp osgb QQ515716030 课程及源代码 Unity3D读取GIS文件原理解析 Unity3d数字地球加载Arcgis数据 s
  • 在Spring-Boot中引入service

    在XXXApplication的同级目录下 添加service文件夹 并在其下添加impl子文件夹 设该service用于与DAO层交互来操作student表 一 在service下添加一个interface 其名称为IStudentSer
  • Vue 3.0 模板语法

    Vue js 使用了基于 HTML 的模板语法 允许开发者声明式地将 DOM 绑定至底层组件实例的数据 所有 Vue js 的模板都是合法的 HTML 所以能被遵循规范的浏览器和 HTML 解析器解析 在底层的实现上 Vue 将模板编译成虚
  • 一文快速学会hadoop完全分布式集群搭建,很详细

    文章目录 前言 一 准备工作 二 克隆三台虚拟机并进行网络配置 克隆 虚拟机克隆引导 修改网络配置 验证 验证方式一 验证方式二 三 安装jdk和hadoop 四 ssh免密登录配置 概述 生成公钥和私钥 把公钥拷贝到三台虚拟机上面去 验证
  • DeformableConv(可形变卷积)理论和代码分析

    DeformableConv 可形变卷积 理论和代码分析 代码参考 DeformableConv代码 理论参考 bilibili视频讲解 1 DeformableConv理论分析 以上图为例进行讲解 在普通卷积中 input输入尺寸为 ba