DETR,YOLO模型计算量(FLOPs)参数量(Params)

2023-11-04

前言

关于计算量(FLOPs)参数量(Params)的一个直观理解,便是计算量对应时间复杂度,参数量对应空间复杂度,即计算量要看网络执行时间的长短,参数量要看占用显存的量。

计算量: FLOPs,FLOP时指浮点运算次数,s是指秒,即每秒浮点运算次数的意思,考量一个网络模型的计算量的标准。越小越好

参数量: Params,是指网络模型中需要训练的参数总数。越小越好

在这里插入图片描述

了解以上概念后,接下来便是如何计算这两个值。
一个很常见的方法便是通过ptflos包来实现。

# -- coding: utf-8 --
import torchvision
from ptflops import get_model_complexity_info

model = torchvision.models.alexnet(pretrained=False)
flops, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True, print_per_layer_stat=True)
print('flops: ', flops, 'params: ', params)

这段代码可以说是即插即用。

DAB-DETR模型

博主以DAB-DETR模型为例,运行时报错,这是由于权重文件于模型配置文件不匹配导致的

权重文件与模型配置不匹配

RuntimeError: Error(s) in loading state_dict for DABDeformableDETR:
	size mismatch for input_proj.0.0.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 128, 1, 1]).
	size mismatch for input_proj.1.0.weight: copying a param with shape torch.Size([256, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 1, 1]).
	size mismatch for input_proj.2.0.weight: copying a param with shape torch.Size([256, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 512, 1, 1]).
	size mismatch for input_proj.3.0.weight: copying a param with shape torch.Size([256, 2048, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 512, 3, 3]).

修改num_channels的值即可,原本为【128,256,512】

  if return_interm_layers:
        # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
        self.strides = [8, 16, 32]
        self.num_channels = [512, 1024, 2048]

推理代码

推理代码如下:几乎所有的DETR类模型的推理代码都是可以通用的。

import json
import os, sys
import torch
import numpy as np

from models import build_DABDETR
from models.dab_deformable_detr import build_dab_deformable_detr
from util.slconfig import SLConfig
from datasets import build_dataset
from util.visualizer import COCOVisualizer
from util import box_ops
model_config_path = "D:/graduate/others/DAB-DETR/config.json" # change the path of the model config file
model_checkpoint_path = "D:/graduate/others/DAB-DETR/checkpoint.pth" # change the path of the model checkpoint
# See our Model Zoo section in README.md for more details about our pretrained models.

args = SLConfig.fromfile(model_config_path)
model, criterion, postprocessors = build_DABDETR(args)
checkpoint = torch.load(model_checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint['model'])
_ = model.eval()
with open('util/coco_id2name.json') as f:
    id2name = json.load(f)
    id2name = {int(k): v for k, v in id2name.items()}
from PIL import Image
import datasets.transforms as T
image = Image.open("./figure/4.jpg").convert("RGB") # load image
# transform images
transform = T.Compose([
    T.RandomResize([800], max_size=1333),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image, _ = transform(image, None)
from ptflops import get_model_complexity_info
model=model.to(args.device)
flops, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True, print_per_layer_stat=True)
print('flops: ', flops, 'params: ', params)
# predict images
with torch.no_grad():
    output = model.cuda()(image[None].cuda())
  # visualize outputs
output = postprocessors['bbox'](output, torch.Tensor([[1.0, 1.0]]).cuda())[0]
thershold = 0.5  # set a thershold
vslzr = COCOVisualizer()
scores = output['scores']
print(len(scores))
labels = output['labels']
boxes = box_ops.box_xyxy_to_cxcywh(output['boxes'])
select_mask = scores > thershold

box_label = [id2name[int(item)] for item in labels[select_mask]]
pred_dict = {
      'boxes': boxes[select_mask],
      'size': torch.Tensor([image.shape[1], image.shape[2]]),
      'box_label': box_label
}

vslzr.visualize(image, pred_dict, savedir=None, dpi=120)

DN-DETR模型

DN-DETR模型推理代码与DAB-DETR模型推理代码大同小异,但问题却不尽相同。

空值问题

indicator0 = torch.zeros([num_queries * num_patterns, 1]).cuda()
TypeError: unsupported operand type(s) for *: 'int' and 'NoneType'

空值问题,给num_patterns赋值=1即可

CPU与GPU运算问题

boxes = boxes * scale_fct[:, None, :]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

数据有的在cpu上,有的在gpu上,在boxes = boxes * scale_fct[:, None, :]后面加上.cuda()

tuple转换问题

此外,还会报错tuple的转换问题

TypeError: tuple indices must be integers or slices, not str

将下面的代码

out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']

改为:

out_logits=outputs[0]['pred_logits']
out_bbox = outputs[0]['pred_boxes']

参数量计算问题

至此,DN-DETR模型推理代码修改无误,但在计算参数量时却出现问题:

File "D:\Anaconda\envs\deformable_detr\lib\site-packages\ptflops\pytorch_ops.py", line 162, in multihead_attention_counter_hook
    q, k, v = input
ValueError: not enough values to unpack (expected 3, got 2)

这里可以看到报错是参数数量出现了问题,我们找到原来的代码,将q, k, v = input改为:

q, k= input, v=k

GPU与CPU运算问题

同样的,这里也报了数据计算位置不一致的问题,如法炮制即可。

 File "E:\graduate\papers\DN-DETR\DN-DETR-main\models\DN_DAB_DETR\DABDETR.py", line 458, in forward
    boxes = boxes * scale_fct[:, None, :]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

DN-DAB-Deformable-DETR模型

参数量运算问题

由于DN-DAB-Deformable-DETR与DN-DAB-DETR共用一套代码,这里出了问题。

    q, k= input
ValueError: too many values to unpack (expected 2)

我们查看一下input的长度,共有三个值,那么原本的写法就没有问题了,改为原本写法即可。

q, k, v= input

报错batch-size问题,其实很好解决,因为我们只是推理,只有一张图片,那么只需要设置为1即可。

至此,DETR类模型推理与计算量,参数量计算解决了。

YOLO模型计算

随后便是YOLO模型,其计算方式类似,原本博主将上面的代码直接拿过来用,但发现却出问题了。
参数量始终为0,这让我百思不得其解。

在这里插入图片描述

随后博主换了另一个工具包。

from thop import profile
print('==> Building model..')
input = torch.randn(1, 3, 224,224)
input = input.cuda()
flops, params = profile(model, (input,))
print('flops: %.2f M, params: %.2f M' % (flops / 1e6, params / 1e6))

就OK了,与DETR模型一样,我们将其放到模型推理代码中直接就可以了。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

DETR,YOLO模型计算量(FLOPs)参数量(Params) 的相关文章

  • 简单博弈论(Nim游戏)

    891 Nim游戏 题目 提交记录 讨论 题解 视频讲解 给定 n 堆石子 两位玩家轮流操作 每次操作可以从任意一堆石子中拿走任意数量的石子 可以拿完 但不能不拿 最后无法进行操作的人视为失败 问如果两人都采用最优策略 先手是否必胜 输入格
  • 图像分类之PaddleClas网络预训练模型加载方法

    PaddlePaddle简介 PaddlePaddle是非常好用的深度学习库 尤其是2 0版本发布以来 高低层API可以自由结合使用 优点如下 可以像tensorflow里面的keras一样非常方便的用几行代码完成模型构建和训练 可以像py
  • 【图像处理】彩色图直方图统计

    首先要知道彩色图是没有直方图的 只能在rgb方向分别求直方图在合并一下 干脆不用这么麻烦 用rgb2gray转到灰度图 再在二维上进行直方图绘制 最后还提供了代码 找出直方图中横坐标 像素值 为50以下的纵坐标 以此为像素的个数 的和 cl

随机推荐

  • 代码精简10倍,责任链模式yyds

    目录 什么是责任链 使用场景 结语 前言最近 我让团队内一位成员写了一个导入功能 他使用了责任链模式 代码堆的非常多 bug 也多 没有达到我预期的效果 实际上 针对导入功能 我认为模版方法更合适 为此 隔壁团队也拿出我们的案例 进行了集体
  • K8S中安装kafka集群问题总结

    k8s下kafkacluster的安装 https github com banzaicloud kafka operator 问题一 镜像无法拉取 由于镜像源在国外被墙的原因 无法从源镜像下载 一般走镜像代理的形式 先从代理仓库docke
  • Ubuntu16.04 完全卸载opencv

    cd XXX opencv build 进入build目录 sudo make uninstall 卸载掉配置路径中的文件 sudo rm r build 删除build文件
  • Windows10家庭版 Windows defender 安全中心显示 页面不可用

    前言 今天使用电脑时出现了如下情况 倒没有发现电脑有什么实质的问题 只是不太理解又觉得好奇 于是就上网查了查 因为没有发现电脑有什么实质的问题 所以犯懒没有鼓捣自己的电脑 以下内容皆由网络所得 由本人整理汇总 希望有所帮助 可能的原因与解决
  • IOU(Intersection Over Union) 概念清晰图解 + python代码示例

    IOU Intersection Over Union 交并比 Intersection over Union IoU 目标检测中使用的一个概念 是产生的候选框 candidate bound 与原标记框 ground truth boun
  • C++ gstreamer函数使用总结

    目录 1 GSteamer的基本API的使用 这个播放mp4报错 这个创建play bin 返回0 不能运行 这个值得看 2 创建元件并且链接起来 3 添加衬垫 添加回调 手动链接衬垫 4 打印gstreamer的版本信息 5 gstrea
  • vim-指定区域查找替换

    vim中的区域查找替换 vim这么强大的工具当然是支持只替换一部分文本啦 那么怎么实现呢 最直接的方式 1 用v选中文本 2 然后 这样的话 命令默认形式是 lt gt s source source abc g 繁琐的方法 a bg fr
  • C/S与P2P的主要区别以及相同点

    C S方式所描述的是进程之间服务和被服务的关系 客户是服务的请求方 服务器是服务的提供方 服务的请求方和提供方都要使用网络核心部分所提供的服务 客户程序被用户调用后运行 在通信时主动向远地服务器发起通信 服务请求 因此 客户程序必须知道服务
  • Python爬虫系列(二)——Python爬虫批量下载百度图片

    1 前言 先贴代码 coding utf8 import requests import json from urllib import parse import os import time class BaiduImageSpider
  • 关于LayUI 表格高度解决方案

    需求是这样式的 我有一个产品列表 但是我想在产品列表中显示产品主图信息 本文只涉及LayUI技巧 不涉及JAVA JS 渲染部分 table render cellHeight 300 elem currentTableId url Pro
  • AIX 上压缩与解压缩

    gz gzip d 或 gunzip gzip Z uncompress compress tar tar xvf tar cvf cpio cpio idumv zip unzip 或 jar xvf tar gz gzip dc tar
  • 心跳包实现的另一种机制

    因为工作关系 经常用到心跳包 之前是在服务端中的连接的实体中保持一个timer 每秒加一 每次服务端接到客户端的心跳 就会把计数置为0 当累加到20秒的时候 服务端会接到客户端抛出的掉线函数回调 就会视为客户端掉线 然后从缓存中删掉掉线用户
  • 一次性搞懂什么是AIGC!

    你知道什么是AIGC吗 不知道 没关系 我来告诉你 AIGC就是人工智能生成内容 Artificial Intelligence Generative Content 也就是让AI自己动手创作各种各样的内容 比如图片 视频 音乐 文字等等
  • DNSPod 查看域名解析的 domain_id 和 record_id

    本文介绍调用 API 获取 DNSPod 域名解析需要的 domain id 和 record id 参数的方法 所有的 DNSPod API 请求都必须提供 login token作为公共参数以验证用户身份是否合法 获取 login to
  • 软件版本命名规范

    1 版本命名规范 1 2 3 20190114 rc 由四部分组成 第一位 1 主版本号 当功能模块有较大的变动 比如增加多个模块或者整体架构发生变化 此版本号由项目决定是否修改 第二位 2 子版本号 当功能有一定的增加或变化 比如增加了对
  • [小程序实现保存图片到相册]

    保存图片到相册 实现逻辑 首先查看用户申请过的权限中是否有 保存图片到相册 如果没有这个权限 则需要先申请权限 弹窗授权 如果用户同意授权则保存图片 如果用户不同意 则跳转到设置页 重新授权 然后再保存图片 查看用户申请的全县有哪些 通过微
  • Httpservlet cannot be resolved to a type的原因与解决方法

    刚开始学习Servlet 在Eclipse中新建了一个Servlet 不过页面上报错 Httpservlet cannot be resolved to a type 显然是Eclipse找不到相应的包 即javax servlet 原因
  • 重建控制文件 recreate control file

    简单总结如下 1 启动到mount2 执行Alter database backup controlfile to trace 3 找到第2步生成的trace文件 并作相应修改 只保留创建语句4 shutdown并启动到nomount 执行
  • 2023备战金三银四,Python自动化软件测试面试宝典合集(二)

    马上就又到了程序员们躁动不安 蠢蠢欲动的季节 这不 金三银四已然到了家门口 元宵节一过后台就有不少人问我 现在外边大厂面试都问啥 想去大厂又怕面试挂 面试应该怎么准备 测试开发前景如何 面试 一个程序员成长之路永恒绕不过的话题 每每到这个时
  • DETR,YOLO模型计算量(FLOPs)参数量(Params)

    前言 关于计算量 FLOPs 参数量 Params 的一个直观理解 便是计算量对应时间复杂度 参数量对应空间复杂度 即计算量要看网络执行时间的长短 参数量要看占用显存的量 计算量 FLOPs FLOP时指浮点运算次数 s是指秒 即每秒浮点运