yolo 推理 nms

2023-11-10

测试代码 

另外一个说明cv2绘制不了中文,但可以用其他包实现。

from pathlib import Path

import cv2
import torch

from models.common import DetectMultiBackend
from utils.dataloaders import LoadImages
from utils.general import Profile, increment_path, non_max_suppression, scale_boxes
from utils.plots import Annotator
from utils.torch_utils import select_device

device = 'cpu'
weights = 'D:\PycharmProjects\swallow\wights\yolov5s.pt'
device = select_device(device)
dnn = False
half = False
data = 'D:\PycharmProjects\swallow\config\coco128.yaml'

model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)

source = 'D:\PycharmProjects\swallow\data\images'
imgsz = (640, 640)
stride = 32
pt = True
vid_stride = 1
bs = 1  # batch_size
conf_thres = 0.25  # confidence threshold
iou_thres = 0.45  # NMS IOU threshold
classes = [0, 1, 2, 3, 4]
agnostic_nms = False
max_det = 1000
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup
seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
for i, (path, im, im0s, vid_cap, s) in enumerate(dataset):
    with dt[0]:
        im = torch.from_numpy(im).to(model.device)
        im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
        im /= 255  # 0 - 255 to 0.0 - 1.0
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
    with dt[1]:
        pred = model(im, augment=True, visualize=False)
        # NMS
    with dt[2]:
        pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
    print(f'预测数据:{pred}')
    for i, det in enumerate(pred):  # per image
        p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
        p = Path(p)  # to Path
        det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
        for d in det:
            cv2.rectangle(im0, (int(d[0]), int(d[1])), (int(d[2]), int(d[3])), (0, 0, 255), 2)
    cv2.imshow('name', im0)
    cv2.waitKey(0)

DetectMultiBackend:

支持各种模型推理:

# Usage:
#   PyTorch:              weights = *.pt
#   TorchScript:                    *.torchscript
#   ONNX Runtime:                   *.onnx
#   ONNX OpenCV DNN:                *.onnx --dnn
#   OpenVINO:                       *_openvino_model
#   CoreML:                         *.mlmodel
#   TensorRT:                       *.engine
#   TensorFlow SavedModel:          *_saved_model
#   TensorFlow GraphDef:            *.pb
#   TensorFlow Lite:                *.tflite
#   TensorFlow Edge TPU:            *_edgetpu.tflite
#   PaddlePaddle:                   *_paddle_model

1.首先根据文件后缀判断文件类型。

pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w)

2:初始化模型

        elif jit:  # TorchScript
            LOGGER.info(f'Loading {w} for TorchScript inference...')
            extra_files = {'config.txt': ''}  # model metadata
            model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
            model.half() if fp16 else model.float()
            if extra_files['config.txt']:  # load metadata dict
                d = json.loads(extra_files['config.txt'],
                               object_hook=lambda d: {int(k) if k.isdigit() else k: v
                                                      for k, v in d.items()})
                stride, names = int(d['stride']), d['names']

3:forward调用模型

        elif self.jit:  # TorchScript
            y = self.model(im)

结合export.py 工具,可以导出不同的模型,运行不同形式的模型。

Detect:

训练时候的损失函数:

https://mp.csdn.net/mp_blog/creation/editor/128985650

                pxy = pxy.sigmoid() * 2 - 0.5
                pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i]

推理还原代码:

self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

xy, wh, conf = x[i].sigmoid().split((2, 2, self.nc + 1), 4)
xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
wh = (wh * 2) ** 2 * self.anchor_grid[i]  # wh
y = torch.cat((xy, wh, conf), 4)

解释:

yolo模型是基于特征金字塔。比如原始图片大小(640, 480), 那么他会按步长(8, 16, 32)下降得到新的三张特征图[(80, 60), *(40, 30), ,]。 那么还原回去是不是也应该乘以步长, 其实从损失函数可以看出,模型预测的只是一个偏移。所以还原回去,按照原定方式还原就行了。

模型输出:

z.append(y.view(bs, self.na * nx * ny, self.no))

本来应该是(1, 3, 80, 60, 85)  含义是:有一张图片,把它分成 (80, 60)的网格,每个网格有3个先验框。每个先验框预测 box(x, y, w ,h) 4 + 置信度 (1)+ 类别热编码(80)。

推理的时候我们只关心,预测的物体。所以view了一下。含义为:预测了几张图片,总共预测了多少物体(其中大部分是背景,因为存在3张特征图,预测量是非常恐怖的)

nms:

        1: 根据置信度,过滤大量的背景或者不符合的预测值

xc = prediction[..., 4] > conf_thres  # candidates

    for xi, x in enumerate(prediction):  # image index, image inference
        x = x[xc[xi]]  # confidence

       2:box坐标转换

box = xywh2xyxy(x[:, :4]) 

       3: 计算得分,得到预测类别最高得分, 过滤掉不符合的类别

             类别的得分,是置信度 * 类别概率的综合分数。但是判别标准还是置信度阈值。

 x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

 conf, j = x[:, 5:mi].max(1, keepdim=True)
 x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

   4:根据置信度排序

x = x[x[:, 4].argsort(descending=True)]  # sort by confidence

5:计算nms

boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS

参考资料

NMS(非极大值抑制)_zouxiaolv的博客-CSDN博客_非极大值抑制

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

yolo 推理 nms 的相关文章

随机推荐

  • 多维时序

    多维时序 MATLAB实现CNN BiLSTM Attention多变量时间序列预测 目录 多维时序 MATLAB实现CNN BiLSTM Attention多变量时间序列预测 预测效果 基本介绍 模型描述 程序设计 参考资料 预测效果 基
  • Python图像相似度2种方法和嵌入空间度量学习

    图像相似度 方法 1 在本文中 我们将介绍如何使用图像相似性量度库来比较图像 根据库的文档 我们可以使用八种不同的评估指标来计算图像之间的相似度 幸运的是 所有可怕的数学运算已为我们实现 我们可以立即开始测量图像相似度 我们只需要调用所选评
  • Lightroom无法在卷计算机上,lightroom无法正常启动怎么办?解决lightroom无法启动方法...

    lightroom在图片的后期处理中占有相当重要的地位 很多用户反映他们的lightroom打不开了 有的是安装完成就无法使用 有一部分是之前能使用突然打不开 这其中又分为有警告框和无警告框 影响了工作进度 想了很多办法都未能解决 那么li
  • 编译SSH代码时,报错configure:error:*** working libcrypto not found,check config.log的原因分析及解决方案

    在将ssh移植到龙芯1B核心板的过程中 当编译openssh 8 0p1工具时 出现了 configure error working libcrypto not found check config log 的报错 根据提示 在opens
  • 目前为止最全的微信小程序项目实例

    wx gesture lock 微信小程序的手势密码 WXCustomSwitch 微信小程序自定义 Switch 组件模板 WeixinAppBdNovel 微信小程序demo 百度小说搜索 shitoujiandaobu 小程序 石头剪
  • BeanUtils.copyProperties,忽略目标对象中不为空的字段

    方法 copyProperties Object source Object target String ignoreProperties 要求 复制对象时 目标对象中不为空的数据 使用 BeanUtils copyProperties s
  • jsp页面中JSTL/EL标签引用java后台静态static字段的方法总结

    为什么使用该功能 项目中的每个页面都包含产品名称 Logo 版本等信息 我希望修改一处 其它所有的全部跟着变 有同学会说那就都引用一个页面 就Ok了 但是我希望这些信息都是可以通过后台代码修改的 修改后保存到数据库和一个静态类中 其实后台直
  • [499]openstack swift 的UI客户端

    了解一下cloudyberry提供的openstack swift客户端 分为收费版和免费版 主要试用了一下免费版 做的还是蛮精致的 很大程度上方便了我们上传 下载 浏览swift上的文件 非常好用 在这里推荐一下 cloudyberry下
  • 华为OD机试 - 斗地主之顺子(Java)

    题目描述 在斗地主扑克牌游戏中 扑克牌由小到大的顺序为 3 4 5 6 7 8 9 10 J Q K A 2 玩家可以出的扑克牌阵型有 单张 对子 顺子 飞机 炸弹等 其中顺子的出牌规则为 由至少5张由小到大连续递增的扑克牌组成 且不能包含
  • 每日一道面试题之介绍一下4+1视图模型!

    4 1视图模型是一种用于软件系统设计和开发的模型 它由4个逻辑视图和一个场景视图组成 每个视图都关注系统的不同方面 为的就是尽可能实现一个全面的系统设计 逻辑视图 描述了软件系统的功能和业务逻辑 它包括了系统的结构和组件之间的关系 以及它们
  • 二叉树的中序遍历(C语言)

    我们从两个方向讲解二叉树的中序遍历 递归 迭代 一 递归 思想 从根节点开始向其的左孩子遍历 一直访问每个节点的左孩子 当其走到NULL时返回 返回时记录每个节点的数值 然后访问该节点的右孩子 如果为NULL直接返回上一层 如果不为NULL
  • Twins: Revisiting the Design of Spatial Attention inVision Transformers解读

    文章 https arxiv org abs 2104 13840 代码 GitHub Meituan AutoML Twins Two simple and effective designs of vision transformer
  • Protobuf Java (2)

    接上一篇文章 Protobuf Java 1 接下来写一个demo 使用protobuf 读写消息 目录 1 写消息 2 读一个消息 3 扩展Protocol Buffer 1 写消息 现在让我们尝试使用协议缓冲区类 您希望地址簿应用程序能
  • CentOS7主机名的查看和修改

    CentOS7主机名的查看和修改 在CentOS7中 有三种定义的主机名 静态的 Static hostname 静态 主机名也称为内核主机名 是系统在启动时从 etc hostname自动初始化的主机名 瞬态的 Tansient host
  • Ping 命令

    PING Packet Internet Groper 因特网包探索器 Ping命令是Windows系列自带的一个可执行命令 利用它可以检查网络是否能够连通 并且能够帮助我们分析判定网络故障 ping的发送和接收 同一个子网中的源主机对目的
  • html ui组件,UI组件

    Bootstrap 天然响应式 12分栏 cnpm install bootstrap 安装相关包 在index html中引入文件后才可以用 如下 ElementUI 24分栏 elementUI使用 安装 element ui cnpm
  • Django 启动报错 mysqlclient 1.4.0 or newer is required; you have 0.9.3

    报错原因 MySQLclient 目前只支持到 Python3 4 这里使用了更高版本的 python 那么需要 我们在Django 配置文件目录下 也就是setting py 同级目录下 配置指定版本的mysqlclient pymysq
  • Flowable工作流引擎的使用2(BPMN结构及节点介绍)

    Flowable工作流引擎的使用 2BPMN结构介绍 上一篇讲到了flowable如何使用 用了一个简单的demo 演示了一下流程的创建 发起 审核 查询等功能 内容不多但是引申出很多的概念 BPMN deployId processId
  • 数据分析笔记—数据仓库篇

    数据仓库 数据仓库 Data Warehouse 可简写为DW或DWH 数仓等 它仅适用于查询和分析 通常涉及大量的历史数据 数据仓库中的数据一般来自应用日志文件 数据埋点 和事务应用 实际发生的业务记录的数据 等广泛来源 一个数据仓库通常
  • yolo 推理 nms

    测试代码 另外一个说明cv2绘制不了中文 但可以用其他包实现 from pathlib import Path import cv2 import torch from models common import DetectMultiBac