CenterNet姿势估计decode部分代码解读

2023-11-09

代码链接:https://github.com/xingyizhou/CenterNet/blob/1085662179604dd4c2667e3159db5445a5f4ac76/src/lib/models/decode.py#L497

代码位置:src/lib/models/decode.py

代码注释

def multi_pose_decode(heat, wh, kps, reg=None, hm_hp=None, hp_offset=None, K=100):
  '''
    :param heat: keypoint heatmap 定位目标中心点的heatmap
    :param wh: object size 确定矩形宽高
    :param kps: joint locations 相对于目标中心的各关键点偏移
    :param reg: local offset 包围框的偏移补偿
    :param hm_hp: joint heatmap 一般的关键点估计heatmap
    :param hp_offset: joint offset 关键点估计的偏移
    :param K: top-K
    :return:
  '''
  batch, cat, height, width = heat.size() # cat类别数
  num_joints = kps.shape[1] // 2 # 需要估计的关键点数是通道数的一半
  # heat = torch.sigmoid(heat)
  # perform nms on heatmaps
  heat = _nms(heat) # 通过3*3最大池化找出局部最大值

  # 找到局部最大值里的top-K,返回[得分, 索引, 类别, Y值list, X值list]
  scores, inds, clses, ys, xs = _topk(heat, K=K) 
  
  # 根据top-K的索引查找并收集对应的关键点偏移量
  kps = _transpose_and_gather_feat(kps, inds) 
  kps = kps.view(batch, K, num_joints * 2)

  # 将关键点偏移量加上中心点坐标,得到相对于图像原点的关键点坐标
  kps[..., ::2] += xs.view(batch, K, 1).expand(batch, K, num_joints)
  kps[..., 1::2] += ys.view(batch, K, 1).expand(batch, K, num_joints)
  if reg is not None:
    # 如果用了关键点量化误差补偿,则解码并加到先前的结果上
    reg = _transpose_and_gather_feat(reg, inds)
    reg = reg.view(batch, K, 2)
    xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
    ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
  else:
    # 如果没有用该分支,则都加0.5,减少量化误差
    xs = xs.view(batch, K, 1) + 0.5
    ys = ys.view(batch, K, 1) + 0.5
  # 查找对应的包围框宽高信息
  wh = _transpose_and_gather_feat(wh, inds)
  wh = wh.view(batch, K, 2)
  clses  = clses.view(batch, K, 1).float()
  scores = scores.view(batch, K, 1)

  # 根据中心点坐标和包围框宽高计算xmin, ymin, xmax, ymax
  bboxes = torch.cat([xs - wh[..., 0:1] / 2, 
                      ys - wh[..., 1:2] / 2,
                      xs + wh[..., 0:1] / 2, 
                      ys + wh[..., 1:2] / 2], dim=2)

  # 一般的关键点估计分支
  if hm_hp is not None:
      hm_hp = _nms(hm_hp) # 通过3*3最大池化找极值
      thresh = 0.1
      # kps原shape[b x K x 2N] => [b * N * K * 2]
      kps = kps.view(batch, K, num_joints, 2).permute(
          0, 2, 1, 3).contiguous() # b x J x K x 2
      # 添加一维[b * N * K * K * 2]
      reg_kps = kps.unsqueeze(3).expand(batch, num_joints, K, K, 2)
      # 对每个channel取top-K,即取到各种类型关键点的top-K
      hm_score, hm_inds, hm_ys, hm_xs = _topk_channel(hm_hp, K=K) # b x J x K
      # 如果起用了关键点分支的偏移,则对关键点坐标进行校正
      if hp_offset is not None:
          hp_offset = _transpose_and_gather_feat(
              hp_offset, hm_inds.view(batch, -1))
          hp_offset = hp_offset.view(batch, num_joints, K, 2)
          hm_xs = hm_xs + hp_offset[:, :, :, 0]
          hm_ys = hm_ys + hp_offset[:, :, :, 1]
      else:
          hm_xs = hm_xs + 0.5
          hm_ys = hm_ys + 0.5

      # 去掉小于阈值的
      mask = (hm_score > thresh).float()
      hm_score = (1 - mask) * -1 + mask * hm_score
      hm_ys = (1 - mask) * (-10000) + mask * hm_ys
      hm_xs = (1 - mask) * (-10000) + mask * hm_xs
      # 使用一般的关键点估计网络预测出的关键点
      hm_kps = torch.stack([hm_xs, hm_ys], dim=-1).unsqueeze(
          2).expand(batch, num_joints, K, K, 2)
      # 全排列计算距离
      dist = (((reg_kps - hm_kps) ** 2).sum(dim=4) ** 0.5)
      min_dist, min_ind = dist.min(dim=3) # b x J x K
      hm_score = hm_score.gather(2, min_ind).unsqueeze(-1) # b x J x K x 1
      min_dist = min_dist.unsqueeze(-1)
      min_ind = min_ind.view(batch, num_joints, K, 1, 1).expand(
          batch, num_joints, K, 1, 2)
      hm_kps = hm_kps.gather(3, min_ind)
      hm_kps = hm_kps.view(batch, num_joints, K, 2)
      l = bboxes[:, :, 0].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
      t = bboxes[:, :, 1].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
      r = bboxes[:, :, 2].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
      b = bboxes[:, :, 3].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
      # 根据以下逻辑挑选在一般关键点估计分支输出上最终可以作为关键点refine结果的点(下面代码是剔除的逻辑):
      # 1. 落在包围框内
      # 2. 得分高于阈值
      # 3. 与基于中心点回归出的对应关键点距离不能超过包围框尺寸的.3倍
      mask = (hm_kps[..., 0:1] < l) + (hm_kps[..., 0:1] > r) + \
             (hm_kps[..., 1:2] < t) + (hm_kps[..., 1:2] > b) + \
             (hm_score < thresh) + (min_dist > (torch.max(b - t, r - l) * 0.3))
      mask = (mask > 0).float().expand(batch, num_joints, K, 2)
      # 使用匹配成功的refine关键点 + 其余的基于中心点回归出的关键点
      kps = (1 - mask) * hm_kps + mask * kps
      kps = kps.permute(0, 2, 1, 3).contiguous().view(
          batch, K, num_joints * 2)
  detections = torch.cat([bboxes, scores, kps, clses], dim=2)
    
  return detections

思考

 关键点的refine思想,和center point的refine思想相同。都是通过预测一个尺寸同对应heatmap的通道数为2的feature map,来得到heatmap上不同位置处,对应的X/Y方向上的offset

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

CenterNet姿势估计decode部分代码解读 的相关文章

随机推荐

  • macOS 视频格式转换:ffmpeg + shell 脚本【最优方案】【免费 + 高效】

    效果完美 开始转换 成功输出 ffmpeg 下载 github 开源下载 下载地址 https ffmpeg org download html shell 脚本 你的用户名 替换成你得自己的对应路劲 比如你下载的 ffmpeg 躲在路劲
  • windows的磁盘操作之七——获取当前所有的物理磁盘号

    有了前几节的基础后 本节给出一个更复杂但却非常实用的例子 很多情况下 我们想知道当前系统下安装了多少块磁盘 他们的物理驱动器号都是多少 每一块磁盘上有多少个分区 分区号怎么分布 每个分区大小是多少 这就类似于我们打开windows 的磁盘管
  • c++的工程文件的编译顺序

    以前一直以为 vs在编译c 文件时候是从头文件开始编译的 而每个头文件对应的源文件只是头文件定义中的一些实现而已 源文件不参与编译 今天经过同学指点并实践之后才发现 其实不是这样的 从中受益颇多 c 编译的时候实际上只编译源文件 而不编译头
  • 416. 分割等和子集

    题目描述 给你一个 只包含正整数 的 非空 数组 nums 请你判断是否可以将这个数组分割成两个子集 使得两个子集的元素和相等 示例 1 输入 nums 1 5 11 5 输出 true 解释 数组可以分割成 1 5 5 和 11 示例 2
  • nginx实战总结-request_time和upstream_response_time详解

    一 前言 这个主要是日志模块的延伸 这两个参数 在实战中非常重要 因此提出来单独说 二 图解 从上图中得出以下结论 打印日志是在最后一个步骤 也就是说整套请求完毕后 进行打印 请求的整套时间线 1 客户端 request gt nginx
  • 最大流解决医生排班问题

    目录 问题描述 场景建模 Ford Fulkerson方法 Edmonds karp算法 Dinic算法 问题描述 一个医院有n名医生 现有k个公共假期需要安排医生值班 每一个公共假期由若干天 假日 组成 第j个假期包含的假日用 Dj表示
  • Python接口自动化测试之文件上传

    在接口测试中 经常会涉及到文件上传 文件上传一般包含的文件是图片 视频以及如csv excel 记事本等文件 它的请求头中Content Type对应的value值是multipart form data 这里依据实际的案例来说明文件上传的
  • Makefile 神奇:驾驭编译的力量

    一 make和Makefile 当谈到 make 和 Makefile 时 通常是指构建工具 make 和用于描述编译和构建过程的文本文件 Makefile make 是一个在类Unix系统中广泛使用的构建工具 它基于文件的时间戳比较 只编
  • 【Vue】生命周期回调函数

    生命周期 又名 生命周期回调函数 生命周期函数 生命周期钩子 程序员间沟通常称生命周期钩子 是什么 Vue在关键时刻帮我们调用的一些特殊名称的函数 生命周期函数的名字不可更改 但是函数的具体内容是程序员根据需求编写的 生命周期函数中的 th
  • java中的Socket编程

    基于Socket的java网络编程 网络上的两个程序通过一个双向的通讯连接实现数据的交换 这个双向链路的一端成为一个socket Socket通常用来实现客户方和服务方的连接 Socket是TCP IP协议的一个十分流行的编程界面 一个so
  • window环境下 —Apache 2.4下载、安装配置与卸载

    一 Apache的下载 1 下载地址 https www apachehaus com cgi bin download plx 2 安装Apache 解压后打开conf文件夹下httpd conf文件 修改Apache目录地址 Defin
  • python螺旋矩阵

    Python 螺旋矩阵 给你一个正整数 n 生成一个包含 1 到 n2 所有元素 且元素按顺时针顺序螺旋排列的 n x n 正方形矩阵 matrix class Solution def generateMatrix self n int
  • system.ComponentModel.Win32Exception (0x80004005): 目录名无效。 解决方法

    system ComponentModel Win32Exception 0x80004005 目录名无效 解决方法 参考文章 1 system ComponentModel Win32Exception 0x80004005 目录名无效
  • Neural Filters用不了怎么办?推荐uminar AI for Mac人工智能照片编辑软件

    Luminar AI 1 3 0 for Mac是macOS第一款完全人工智能的照片编辑软件 摄影爱好者和专业摄影师 设计师必备的后期软件 Luminar AI 可以作为独立的照片编辑软件或作为PS LRC插件使用 功能强大媲美PS的神经滤
  • 【Windows】VScode终端添加GitBash,终端直接调用git

    1 打开VScode 文件 gt gt 首选项 gt gt 设置 搜索 shell windows 点击settings json编辑 把下面的语句复制进去 terminal integrated profiles windows Powe
  • 台达b3伺服参数设置方法_台达伺服驱动器参数设置一览表

    台达伺服驱动器参数设置一览表 2020 12 23 台达伺服驱动器的参数设置分为八大群组 从P0到P7 参数群组定义如下 群组 0 监控参数 例 P0 xx 群组 1 基本参数 例 P1 xx 群组 2 扩展参数 例 P2 xx 群组 3
  • oracle 表 xml,详细分析Oracle XML数据

    在向大家详细介绍Oracle XML数据之前 首先让大家了解下Oracle 11g 然后全面介绍Oracle XML数据 在Oracle 11g可以使用CLOB及二进制两种方式保存XML信息 灵活性很高 Oracle 11g还支持针对XML
  • ElementUI/ElementPlus+笔记

    如何修改特定文件下使用的Element组件样式 在哪修改样式 在修改element样式时 最好在scoped中修改避免全局污染 如果在scoped中修改样式不生效就在全局中修改 但是在要修改的样式外面套一层class 避免修改了所有页面 使
  • EasyExcel读写Excel

    转载 侵删 原文链接 https mp weixin qq com s T xBuoYgj1NuM7 yHe084Q 最近读者小 H 在知识星球中给阿粉发来私信 阿粉 最近我在负责公司报表平台开发 需要导出报表到 Excel 中 每次使用
  • CenterNet姿势估计decode部分代码解读

    代码链接 https github com xingyizhou CenterNet blob 1085662179604dd4c2667e3159db5445a5f4ac76 src lib models decode py L497 代