ECCV 2022

2023-10-31

ECCV 2022 | Learning Implicit Feature Alignment Function for Semantic Segmentation概述与代码分析

在这里插入图片描述

主要工作

在这里插入图片描述

在这里插入图片描述

基于隐神经表示设计了一种隐式特征对齐函数,来替换现有的基于插值的不同分辨率特征对齐方案。可以更加方便和高效的对齐多个不同分辨率的特征。

原始的隐式特征函数:

在这里插入图片描述
不考虑专业术语。直观来讲,隐式特征函数本身是基于原始特征和目标特征之间的坐标关系,构建了一个从原始特征到目标特征映射变换。其中的变换关系可以通过神经网络学习和建模。

这一过程需要提供以下三种信息:

  • 已有的原始特征 z z z
  • 原始特征对应的连续的(归一化)网格坐标 x x x
  • 我们想要生成的目标特征/预测对应的连续的(归一化)网格坐标 x q x_q xq

注意这里强调了归一化坐标。这一方法核心的一个假定是网格坐标系本身是对齐的,可能只是单位刻度上有差异。

通过这些信息,我们可以利用坐标之间的相对关系,从原始特征变换得到目标特征/预测。

需要注意的是,这一变换过程中,主要关注坐标系中与目标位置最邻近的原始特征点。

在此基础上,作者们引入了相对位置编码获得了更好的对齐效果:

在这里插入图片描述
在这里插入图片描述

通过同时集成多个不同层级的特征来实现对于最终预测的检索和计算:

在这里插入图片描述

实验结果

在这里插入图片描述

在这里插入图片描述

核心代码解析

  • https://github.com/hzhupku/IFA/blob/main/pyseg/models/ifa_utils.py
  • https://github.com/hzhupku/IFA/blob/main/pyseg/models/ifa.py
  • https://github.com/hzhupku/IFA/blob/main/pyseg/models/fpn_ifa.py
import torch
import torch.nn as nn
import torch.nn.functional as F


def make_coord(hw, flatten=True):
    """构建网格坐标系,原点位于各轴有效范围的中心点。

    使用的网格坐标系的三个参考点:网格区域的左右边界为-1和1或者ranges的两个值,正中心为0。

    返回的网格坐标为 [N,[...,]len(hw)],其中最后一维表示具体的坐标,坐标顺序与hw中轴的顺序一致。
    """
    start_idx, end_idx = -1, 1

    axes_grid_centers = []
    for i, n in enumerate(hw):
        # 单一轴向的半个网格的宽度
        width_of_half_grid = (end_idx - start_idx) / (2 * n)

        # 这里计算的是各个方形网格区域的中心点坐标。
        start_grid_center = start_idx + width_of_half_grid
        grid_centers = (
            start_grid_center + (2 * width_of_half_grid) * torch.arange(n).float()
        )
        # 使用linspace替换会导致精度无法对齐
        # end_grid_center = end_idx - width_of_half_grid
        # grid_centers = torch.linspace(start_grid_center, end_grid_center, steps=n)
        axes_grid_centers.append(grid_centers)

    paired_grid_centers = torch.stack(
        torch.meshgrid(*axes_grid_centers, indexing="ij"), dim=-1
    )
    if flatten:
        paired_grid_centers = paired_grid_centers.reshape(
            -1, paired_grid_centers.shape[-1]
        )
    return paired_grid_centers


def ifa_feat_ann(src, tgt_hw, stride=1, local_ensemble=False):
    bs, src_h, src_w = src.shape[0], src.shape[-2], src.shape[-1]
    tgt_h, tgt_w = tgt_hw

    coord_tgt_hw = make_coord((tgt_h, tgt_w)).to(device=src.device)
    # hw,[tgt_h_id,tgt_w_id] =(repeat)=> bs,hw,[tgt_h_id,tgt_w_id] in (-1,1)
    coord_tgt_hw = coord_tgt_hw.unsqueeze(0).expand(bs, *coord_tgt_hw.shape)
    # 使用后可以与原始实现对齐,但是实际属于冗余操作
    # coord_tgt_hw = (coord_tgt_hw + 1) / 2 * 2 - 1

    coord_src_hw = make_coord((src_h, src_w), flatten=False).to(device=src.device)
    # src_h,src_w,[src_h_id,src_w_id]
    # => [src_h_id,src_w_id],src_h,src_w
    coord_src_hw = coord_src_hw.permute(2, 0, 1)
    # =(repeat)=> bs,[src_h_id,src_w_id],src_h,src_w in (-1,1)
    coord_src_hw = coord_src_hw.unsqueeze(0).expand(bs, 2, src_h, src_w)

    if local_ensemble:
        # 利用局部ensemble来缓解基于索引的预测方式导致的预测不连续的问题
        # 直接利用目标位置与周围四个隐编码位置之间的包围矩形面积来加权组合获得的四个预测,
        # 从而平滑索引改变时带来的预测变化。
        # 这一加权平滑的方式基本是沿用了双线性插值的思路。
        tgt_x_shifts = [-1, 1]
        tgt_y_shifts = [-1, 1]
        eps_shift = 1e-6

        rel_coord_hws = []
        src2tgt_feats = []
        areas = []
    else:
        tgt_x_shifts, tgt_y_shifts, eps_shift = [0], [0], 0

    # tgt网格坐标系下的相对步长
    tgt_x_stride = stride / tgt_w
    tgt_y_stride = stride / tgt_h

    for tgt_x_shift in tgt_x_shifts:
        for tgt_y_shift in tgt_y_shifts:
            # bs,hw,[tgt_w_id,tgt_h_id] in (-1,1)
            coord_tgt_xy = coord_tgt_hw.flip(-1).clone()
            # 在考虑局部ensemble的时候,这里对tgt坐标进行一个单位的相对偏移后再对src进行查询与映射
            coord_tgt_xy[:, :, 0] += tgt_x_shift * tgt_x_stride + eps_shift
            coord_tgt_xy[:, :, 1] += tgt_y_shift * tgt_y_stride + eps_shift
            coord_tgt_xy.clamp_(-1 + 1e-6, 1 - 1e-6)
            # bs,1,hw,[tgt_w_id,tgt_h_id]
            coord_tgt_xy = coord_tgt_xy.unsqueeze(1)

            # 使用tgt网格坐标对src特征网格坐标调整
            # 采样 bs,[src_h_id,src_w_id],src_h,src_w 到 bs,[src_h_id',src_w_id'],1,hw
            coord_src2tgt_hw = F.grid_sample(
                coord_src_hw, coord_tgt_xy, mode="nearest", align_corners=False
            )
            # bs,hw,[src_h_id',src_w_id']
            coord_src2tgt_hw = coord_src2tgt_hw[:, :, 0, :].permute(0, 2, 1)

            # 与nearest latent code,即这里的src,相对坐标偏移
            rel_coord_hw = coord_tgt_hw - coord_src2tgt_hw
            rel_coord_hw[:, :, 0] *= src_h  # src.shape[-2]
            rel_coord_hw[:, :, 1] *= src_w  # src.shape[-1]

            # 使用目标网格坐标对输入特征重新采样
            # bs,c,src_h,src_w => bs,c,1,tgt_h*tgt_w => bs,tgt_h*tgt_w,c
            src2tgt_feat = F.grid_sample(
                src, coord_tgt_xy, mode="nearest", align_corners=False
            )
            src2tgt_feat = src2tgt_feat[:, :, 0, :].permute(0, 2, 1)

            if local_ensemble:
                rel_coord_hws.append(rel_coord_hw)
                src2tgt_feats.append(src2tgt_feat)
                # 在局部ensemble的时候,需要统计tgt与周围四个src位置之间矩形的面积,用来加权平均从而平滑结果
                # 而面积的计算正好是相对坐标乘积的绝对值
                area = torch.abs(rel_coord_hw[:, :, 0] * rel_coord_hw[:, :, 1])
                areas.append(area + 1e-9)

    if not local_ensemble:
        return rel_coord_hw, src2tgt_feat
    else:
        return rel_coord_hws, src2tgt_feats, areas


class ifa_simfpn(nn.Module):
    def __init__(...):
        super().__init__()
        if learn_pe:
            self.pos1 = PositionEmbeddingLearned(self.pos_dim // 2)
            self.pos2 = PositionEmbeddingLearned(self.pos_dim // 2)
            self.pos3 = PositionEmbeddingLearned(self.pos_dim // 2)
            self.pos4 = PositionEmbeddingLearned(self.pos_dim // 2)
        if ultra_pe:
            self.pos1 = SpatialEncoding(2, self.pos_dim, require_grad=require_grad)
            self.pos2 = SpatialEncoding(2, self.pos_dim, require_grad=require_grad)
            self.pos3 = SpatialEncoding(2, self.pos_dim, require_grad=require_grad)
            self.pos4 = SpatialEncoding(2, self.pos_dim, require_grad=require_grad)
            self.pos_dim += 2

        in_dim = 4 * (256 + self.pos_dim)
        if unfold:
            in_dim = 4 * (256 * 9 + self.pos_dim)

        self.imnet = ...  # in_dim -> num_classes

    def forward(self, x, size, level=0, after_cat=False):
        h, w = size
        if after_cat:
            return self.imnet(x).reshape(x.shape[0], -1, h, w)

        # Feature unfolding: 为了丰富隐码包含的信息,对特征中3×3相邻隐码合并
        if self.unfold:
            x = F.unfold(x, 3, padding=1).reshape(
                x.shape[0], x.shape[1] * 9, x.shape[2], x.shape[3]
            )

        if not self.local:
            rel_coord_hw, src2tgt_feat = ifa_feat_ann(src=x, tgt_hw=[h, w])

            if self.ultra_pe or self.learn_pe:
                rel_coord_hw = eval("self.pos" + str(level))(rel_coord_hw)

            x = torch.cat([rel_coord_hw, src2tgt_feat], dim=-1)
        else:
            rel_coord_hws, src2tgt_feats, areas = ifa_feat_ann(
                src=x,
                tgt_hw=[h, w],
                stride=self.stride,
                local_ensemble=True,
            )

            contexts = []
            for rel_coord_hw, src2tgt_feat, area in zip(
                rel_coord_hws, src2tgt_feats, areas
            ):
                if self.ultra_pe or self.learn_pe:
                    rel_coord_hw = eval("self.pos" + str(level))(rel_coord_hw)
                contexts.append(torch.cat([rel_coord_hw, src2tgt_feat], dim=-1))

            # 这里将对角区域的面积进行了交换。0号与3号,1号与2号
            # 整体的特征组合方式与双线性插值形式一致
            # 关于双线性插值可见 https://blog.csdn.net/qq_58664081/article/details/129079354
            areas[0], areas[3] = areas[3], areas[0]
            areas[1], areas[2] = areas[2], areas[1]
            total_area = torch.stack(areas).sum(dim=0)

            for cxt, area in zip(contexts, areas):
                x = cxt * ((area / total_area).unsqueeze(-1))
        return x


class fpn_ifa(nn.Module):
    def __init__(...):
        super().__init__()
        ...

        self.ifa = ifa_simfpn(
            ultra_pe=ultra_pe,
            pos_dim=pos_dim,
            num_classes=num_classes,
            local=local,
            unfold=unfold,
            stride=stride,
            learn_pe=learn_pe,
            require_grad=require_grad,
            num_layer=num_layer,
        )

    def forward(self, x):
        x1, x2, x3, x4 = x
        aspp_out = ...

        context = []
        h, w = x1.shape[-2], x1.shape[-1]
        target_feat = [x1, x2, x3, aspp_out]

        for i, feat in enumerate(target_feat):
            context.append(self.ifa(feat, size=[h, w], level=i + 1))
        context = torch.cat(context, dim=-1).permute(0, 2, 1)  # B,HW,C -> B,C,HW
        return self.ifa(context, size=[h, w], after_cat=True)

这里代码的设计应当是借鉴自图像超分辨算法LIIF中的设计,代码基本一致https://github.com/yinboc/liif/blob/main/models/liif.py。

本文保留了LIIF中的Local Ensemble和Feature Unfolding的设计,但是不同之处主要有两点:

  • 相对位置信息的使用不同于LIIF中直接将其作为imnet的输入的部分通道,这里使用了位置编码的方式进行处理。
  • 没有使用LIIF中的Cell Decoding。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

ECCV 2022 的相关文章

随机推荐

  • 【Unity问题&错误】list问题

    error CS0305 Using the generic type System Collections Generic List
  • ctfshow web入门 代码审计

    文章目录 web301 web302 web303 web304 web305 web306 web307 web308 web309 web310 web301 下载源码后在checklogin php发现问题代码
  • ChatGPT未来会拥有自我情感和思维吗?

    ChatGPT是一种基于人工智能的聊天机器人 它可以模拟人类的对话 并且可以回答各种问题 目前 ChatGPT已经非常先进 但是它是否会拥有自我情感和思维呢 首先 我们需要明确一点 ChatGPT是一种基于机器学习的算法 它的行为是由程序员
  • 记一次阿里巴巴电话面试题

    前几天投了阿里巴巴校招简历 今天晚上突然来了阿里的面试电话 有点紧张又有点激动 面试的问题问的挺全面 但是有些准备不足 因此回答的并不算太满意 现在整理一下分享给大家 希望进阿里的小伙伴可以来踩踩 1 自我介绍 打电话时我刚吃完饭 上来就介
  • AndroidManifest.xml作用

    今天在看到一篇博客是对于博主对于AndroidManifest xml文件的作用的理解深感赞同 AndroidManifest xml文件就是在安装的时候用来给PackageManagerService进行解析 分析出这个APK的packa
  • 【NeurIPS 2021】ViT 中增强的 Shortcut Connection:Augmented Shortcuts for Vision Transformers

    Augmented Shortcuts for Vision Transformers 论文地址 主要工作 方案简介 基本定义 具体实现 Augmented Shortcuts Efficient Implementation via Ci
  • 权限维持篇---Windows权限维持--隐藏篇

    权限维持篇 Windows权限维持 隐藏篇 文章目录 权限维持篇 Windows权限维持 隐藏篇 前言 一 隐藏文件 二 隐藏账号 三 端口复用 四 进程注入 五 结束 六 我的公众号 前言 攻击者在获取服务器权限后 通常会用一些后门来维持
  • AD 原理图网络未连上,设置DRC报错

    AD原理图整理时 碰到一个网络没有连接 但是DRC检查没有提示有异常的情况 如下图 R7H右端并没有连上 然而原理图检查居然没有问题 在导入PCB时才报错 因为原理图DRC没有报错 想要找到问题 需要蛮多时间 细思极恐 AD其实是可以设置检
  • 数据结构-顺序栈的基本操作的实现(含全部代码)

    主要操作函数如下 InitStack SqStack s 参数 顺序栈s 功能 初始化 时间复杂度O 1 Push SqStack s SElemType e 参数 顺序栈s 元素e 功能 将e入栈 时间复杂度 O 1 Pop SqStac
  • 【黑马程序员】面向对象(五) 第九天

    android培训 java培训 java学习型技术博客 期待与您交流 知识点 异常处理能够使一个方法给它的调用者抛出一个异常 异常发生在一个方法的执行过程中 RuntimeException和Error都是免检异常 其它所有异常都是必检的
  • vite --- 搭建开发环境

    目录 下载安装和初始化VSCode 安装Node js yarn 使用 pnpm 安装与使用 搭建第一个Vite项目 使用 PNPM创建项目 项目目录解读 下载安装和初始化VSCode 1 访问网站 Visual Studio Code C
  • editplus配置python环境 和 php环境

    editplus配置python环境 和 php环境 使用editplus这么久 才知道是可以配置python环境 和 php环境 想来真丢人 这就是自学的痛苦之处 许多时如果不是自己突然想到 只会永远在黑暗中摸索 editplus配置py
  • KNN与CNN

    KNN与CNN相关 KNN K Nearest Neighbor 最邻近分类算法 就是k个最近的邻居的意思 说的是每个样本都可以用它最接近的k个邻居来代表 KNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个
  • windows xp 驱动开发(三)DDK与WDK WDM的区别

    转自 http www cnblogs com hyddd archive 2009 03 15 1412684 html 最近尝试去了解WINDOWS下的驱动开发 现在总结一下最近看到的资料 1 首先 先从基础的东西说起 开发WINDOW
  • 第十四章 AlibabaCloud微服务下的链路追踪系统

    1 微服务架构下的排查问题复杂性概述 两个常 的问题 微服务调 链路出现了问题怎么快速排查 微服务调 链路耗时 怎么定位是哪个服务 链路追踪系统 分布式应 架构虽然满 了应 横向扩展的需求 但是运维和诊断的过程变得越来越复杂 例如会遇到接
  • linux 拷贝文件夹并覆盖另一个文件夹 cp指令

    参考 参考 https m runoob com linux linux comm cp html Linux cp 英文全拼 copy file 命令主要用于复制文件或目录 语法 cp options source dest 或 cp o
  • BoolQueryBuilder 和 wildcardQuery withFilter 查询

    一 BoolQueryBuilder查询说明 BoolQueryBuilder qb QueryBuilders boolQuery 1 返回的文档必须满足must子句的条件 并且参与计算分值 qb must QueryBuilder qu
  • 腾讯云数据库TDSQL:分布式数据库,你真的了解吗?

    分布式数据库进入人们的视野已经很久了 相对于传统的集中式数据库 分布式数据库在高性能 高可用 平滑拓展 高可靠 低成本等许多方面具有优势 但时至今日 关于分布式数据库 似乎一直缺少足够权威和客观的解读 现在 国家白皮书来了 为了明确分布式数
  • Vue SSR(vue服务端渲染)

    SSR的应用场景 1 SEO需求 SEO Search Engine Optimization 搜索引擎优化 是一种利用搜索引擎规则 提高网站在搜索引擎内自然排名的技术 通常这需要页面内容在页面加载完成时便已经存在 前后端分离的纯前端项目
  • ECCV 2022

    ECCV 2022 Learning Implicit Feature Alignment Function for Semantic Segmentation概述与代码分析 论文 https arxiv org abs 2206 0865