ssd.pytorch源码分析(三)— 非极大值抑制NMS

2023-05-16

NMS源码
SSD论文链接

NMS介绍

吴恩达对于NMS(非极大值抑制)的介绍:
在这里插入图片描述
说白了,NMS的作用就是去掉目标检测任务重复的检测框。 例如,一个目标有多个选择框,现在要去掉多余的选择框。怎么做呢?循环执行步骤1和2, 直到只剩下一个框:

  • 1、选出置信度p_c最高的框;
  • 2、去掉和这个框IOU>0.7的框。

相关函数

一、torch.clamp( )

torch.clamp(input, min, max, out=None) → Tensor

将输入input张量每个元素夹紧到区间 [min,max],并返回结果到一个新张量。
类似于numpy中的np.clip
操作定义如下:

	  | min, if x_i < min
y_i = | x_i, if min <= x_i <= max
      | max, if x_i > max

参数:

  • input (Tensor) – 输入张量
  • min (Number) – 限制范围下限
  • max (Number) – 限制范围上限
  • out (Tensor, optional) – 输出张量

例子:

>>> a = torch.randn(4)
>>> a
 1.3869
 0.3912
-0.8634
-0.5468
[torch.FloatTensor of size 4]
>>> torch.clamp(a, min=-0.5, max=0.5)
 0.5000
 0.3912
-0.5000
-0.5000
[torch.FloatTensor of size 4]

二、torch.index_select()

torch.index_select(input, dim, index, out=None) → Tensor

沿着指定维度对输入进行切片。

参数:

  • input (Tensor) – 输入张量
  • dim (int) – 索引的轴
  • index (LongTensor) – 包含索引下标的一维张量
  • out (Tensor, optional) – 目标张量

例子:

>>> x = torch.randn(3, 4)
>>> x

 1.2045  2.4084  0.4001  1.1372
 0.5596  1.5677  0.6219 -0.7954
 1.3635 -1.2313 -0.5414 -1.8478
[torch.FloatTensor of size 3x4]

>>> indices = torch.LongTensor([0, 2])
>>> torch.index_select(x, 0, indices)

 1.2045  2.4084  0.4001  1.1372
 1.3635 -1.2313 -0.5414 -1.8478
[torch.FloatTensor of size 2x4]

>>> torch.index_select(x, 1, indices)

 1.2045  0.4001
 0.5596  0.6219
 1.3635 -0.5414
[torch.FloatTensor of size 3x2]

注意,index_select函数中的参数index表示了有哪些索引值是需要保留的。

三、 torch.numel()

torch.numel(input)->int 

返回input 张量中的元素个数。

复现代码

以下为ssd.pytorch中NMS(实际上在任何anchor based的目标检测框架中都适用)。其中:

  • 为了减少计算量,作者仅选取置信度前top_k=200个框;
  • 代码中包含了IOU的计算。关于IOU计算推荐阅读这篇文章;
def nms(boxes, scores, overlap=0.7, top_k=200):
    """
    输入:
        boxes: 存储一个图片的所有预测框。[num_positive,4].
        scores:置信度。如果为多分类则需要将nms函数套在一个循环内。[num_positive].
        overlap: nms抑制时iou的阈值.
        top_k: 先选取置信度前top_k个框再进行nms.
    返回:
        nms后剩余预测框的索引.
    """
    
    keep = scores.new(scores.size(0)).zero_().long() 
    # 保存留下来的box的索引 [num_positive]
    # 函数new(): 构建一个有相同数据类型的tensor 
    
	#如果输入box为空则返回空Tensor
    if boxes.numel() == 0: 
        return keep
        
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    area = torch.mul(x2 - x1, y2 - y1) #并行化计算所有框的面积
    v, idx = scores.sort(0)  # 升序排序
    idx = idx[-top_k:]  # 前top-k的索引,从小到大
    xx1 = boxes.new()
    yy1 = boxes.new()
    xx2 = boxes.new()
    yy2 = boxes.new()
    w = boxes.new()
    h = boxes.new()

    count = 0
    while idx.numel() > 0:
        i = idx[-1]  # 目前最大score对应的索引
        keep[count] = i #存储在keep中
        count += 1
        if idx.size(0) == 1: #跳出循环条件:box被筛选完了
            break
        idx = idx[:-1]  # 去掉最后一个
        
        #剩下boxes的信息存储在xx,yy中
        torch.index_select(x1, 0, idx, out=xx1)
        torch.index_select(y1, 0, idx, out=yy1)
        torch.index_select(x2, 0, idx, out=xx2)
        torch.index_select(y2, 0, idx, out=yy2)
        
        # 计算当前最大置信框与其他剩余框的交集,不知道clamp的同学确实容易被误导
        xx1 = torch.clamp(xx1, min=x1[i])  #max(x1,xx1)
        yy1 = torch.clamp(yy1, min=y1[i])  #max(y1,yy1)
        xx2 = torch.clamp(xx2, max=x2[i])  #min(x2,xx2)
        yy2 = torch.clamp(yy2, max=y2[i])  #min(y2,yy2)
        w.resize_as_(xx2)
        h.resize_as_(yy2)
        w = xx2 - xx1 #w=min(x2,xx2)−max(x1,xx1)
        h = yy2 - yy1 #h=min(y2,yy2)−max(y1,yy1)
        w = torch.clamp(w, min=0.0) #max(w,0)
        h = torch.clamp(h, min=0.0) #max(h,0)
        inter = w*h
        
		#计算当前最大置信框与其他剩余框的IOU
        # IoU = i / (area(a) + area(b) - i)
        rem_areas = torch.index_select(area, 0, idx)  # 剩余的框的面积
        union = rem_areas + area[i]- inter #并集
        IoU = inter/union  # 计算iou
        
        # 选出IoU <= overlap的boxes(注意le函数的使用)
        idx = idx[IoU.le(overlap)]
    return keep,          count
    	   #[num_remain], num_remain
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

ssd.pytorch源码分析(三)— 非极大值抑制NMS 的相关文章

随机推荐

  • 云计算的部署

    一 云计算的服务和交付模式 基础设施即服务 xff08 Iaas xff09 平台即服务 xff08 Paas xff09 软件即服务 xff08 Saas xff09 衍生出 xff1a 存储即服务 数据库即服务 安全即服务 通信即服务
  • MapReduce的数据流程、执行流程

    MapReduce的大体流程是这样的 xff0c 如图所示 xff1a 由图片可以看到mapreduce执行下来主要包含这样几个步骤 1 首先对输入数据源进行切片 2 master调度worker执行map任务 3 worker读取输入源片
  • 免费下载中国知网、万方学术论文的几种方法(福利合集)

    在国内 xff0c 中国知网收录了最多的期刊论文和硕博士论文 无论学霸学渣 xff0c 都得上去下载论文 如果你的学校在知网购买了相应的下载版权 xff0c 那恭喜你 xff0c 你通过校园网就能免费下载了 但一旦你回了家 xff0c 或学
  • 使用apt离线安装deb包

    文章目录 apt 下载的deb路径阻止apt自动删除缓存文件的方法只下载不安装的方法离线安装deb包离线安装gcc1 下载依赖2 打包下载的deb文件 xff0c 上传到没有外网连接的服务器3 安装deb包 apt 下载的deb路径 默认存
  • haar分类

    今天说一说haar分类算法 首先介绍haar like特征 haar like的特征有边缘特征 线性特征 中心特征和对角线特征 我们使用特征模板来表示特征的计算 xff0c 如图所示 xff1a 这些特征分别对应着不同的矩阵以便于进行计算
  • POI window excel 打开提示部分内容有问题, 是否尝试尽量恢复

    问题如下 window excel 打开报错如下 但是WPS打开正常 问题在于 window excel 冻结窗口只能设置一行 WPS可以设置多行 设置冻结窗口如下 冻结第一行 sheet createFreezePane 0 1 0 1
  • 解决从数据库中取出json数据有转义符

    不处理从数据库取出数据如下 String s1 61 34 34 MsgId 34 1 34 TotalCount 34 10 34 FilterCount 34 8 34 SentCount 34 7 34 ErrorCount 34 0
  • 查询数据报错 com.mysql.cj.exceptions.DataConversionException

    com mysql cj exceptions DataConversionException Caused by java sql SQLDataException Cannot determine value type from str
  • 微信调用接口报错:"errcode":45009,"errmsg":"reach max api daily quota limit hints:

    api请求次数达到最大上限 每个帐号每月共10次清零操作机会 xff0c 清零生效一次即用掉一次机会 xff08 10次包括了平台上的清零和调用接口API的清零 xff09 https developers weixin qq com do
  • @FeignClient注解 中属性 contextId使用

    64 FeignClient注解 中属性 contextId 比如我们有个user服务 xff0c 但user服务中有很多个接口 xff0c 我们不想将所有的调用接口都定义在一个类中 xff0c 比如 xff1a Client span c
  • toString和toJSONString的区别

    Map span class token generics function span class token punctuation lt span String span class token punctuation span Int
  • Neutron运营商网络和租户网络详解

    由租户创建并且管理的网络 xff0c Neutron称之为租户网络 但是Openstack不是万能的 xff0c Neutron也不是万能的 还有很多网络不在Neutron管理范围内 xff08 Neutron称之为外部网络 xff09 有
  • mysql in查询太慢, 使用join优化

    mysql中查询 in 参数太多 导致查询很慢 使用join优化 在实例中in查询话费2s 优化后0 4s span class token keyword SELECT span span class token operator spa
  • Springboot 多数据源事务,切换数据源+事务

    项目有多个数据源 根据配置文件配置的连接数来自动生成多数据源配置 并且使用 aop切换数据源 使用的是 AbstractRoutingDataSource 重写 determineCurrentLookupKey 方法 在切换数据源之前 6
  • Redisson自定义序列化方式

    redissonClient span class token punctuation span span class token function getBucket span span class token punctuation s
  • 方法区使用举例

    span class token keyword public span span class token keyword class span span class token class name MethodAreaDemo span
  • mysql动态字段行转列

    动态行转列 table schema id name s 001 是否吃饭了 s 002 你的汽车品牌 table schema value id user id schema id schema value span class toke
  • freertos学习02-队列 stream buffer message buffer

    1 freertos数据传递简介 在freertos中 xff0c 各个模块都是独立的任务 xff0c 那么任务之间怎么进行大量的数据通信呢 xff1f 在V10版本给出了三种方法 队列queue xff0c 发送固定长度的数据串strea
  • stlink故障修复

    前言 一直用的是国产版stlink xff0c 但是最近手头手头上的两个stlink在下载的时候出故障了 xff0c 无法识别 上淘宝一搜发现涨价了 xff0c 记得以前是20左右 xff0c 现在都要40快一个 于是想着能不能进行修复 百
  • ssd.pytorch源码分析(三)— 非极大值抑制NMS

    NMS源码 SSD论文链接 NMS介绍 吴恩达对于NMS xff08 非极大值抑制 xff09 的介绍 xff1a 说白了 xff0c NMS的作用就是去掉目标检测任务重复的检测框 例如 xff0c 一个目标有多个选择框 xff0c 现在要