神经网络中FLOPs和MACs的计算(基于thop和fvcore.nn)

2023-11-20

输入为(1,1,200,3)的张量

卷积取

nn.Conv2d(1, 64, kernel_size=(8, 1), stride=(2, 1), padding=(0, 0))

为例。

先计算输出的形状

公式为 \frac{n+2p-k}{s}+1

H上为(200+0-8)/2+1=97

W上依然是3

所以输出的形状是(1,64,97,3)

卷积的本质是wx+b,

但是实际计算过程中,是直接w和x一一对应的乘起来,并且将结果都加起来

计算FLOPs时,一般会忽略b,而MACs并不会忽略b

所以对于一个卷积,对应的FLOPs为

97*3*(8*1)*64=148992

而对应的MACs为

97*3*(8*1)*64+97*3*64=167616

后一个97*3*64,就是对应b的数量

用代码计算的话,可以用thop计算MACs,fvcore.nn计算FLOPs

import torch
import torch.nn as nn
import torch.nn.functional as F
from thop import profile
from fvcore.nn import FlopCountAnalysis, parameter_count_table
class net(torch.nn.Module):
    def __init__(self, image_channels=1, n_classes=6):
        super(net, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(image_channels, 64, kernel_size=(8, 1), stride=(2, 1), padding=(0, 0)),
        )

    def forward(self, x):
        cnn_x = self.cnn(x)
        return cnn_x

model = net()
x = torch.randn(1, 1, 200, 3)

macs, params = profile(model, inputs=(x, ))  # ,verbose=False
print("MACs", macs)
print("p", params)

print("@@@@@@@@@@@@@@")

flops = FlopCountAnalysis(model, x)
print("FLOPs", flops.total())
print(parameter_count_table(model))

打印的结果为

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[WARN] Cannot find rule for <class 'torch.nn.modules.container.Sequential'>. Treat it as zero Macs and zero Params.
[WARN] Cannot find rule for <class '__main__.net'>. Treat it as zero Macs and zero Params.
MACs 167616.0
p 576.0
@@@@@@@@@@@@@@
FLOPs 148992
| name            | #elements or shape   |
|:----------------|:---------------------|
| model           | 0.6K                 |
|  cnn            |  0.6K                |
|   cnn.0         |   0.6K               |
|    cnn.0.weight |    (64, 1, 8, 1)     |
|    cnn.0.bias   |    (64,)             |

对应的WARN不用理会,因为这几个类本来就没有计算

其实计算量的判定还有MACCs,MADDs等方法,具体可以参考

CNN的参数量、计算量(FLOPs、MACs)与运行速度_Dr鹏的博客-CSDN博客_模型复杂度

本文参考

网络模型计算量评估_WTHunt的博客-CSDN博客_网络计算量

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

神经网络中FLOPs和MACs的计算(基于thop和fvcore.nn) 的相关文章

  • 人脸识别趟坑历程

    1 人脸识别概述 人脸识别 是基于人的脸部特征信息进行身份识别的一种生物识别技术 用摄像机或摄像头采集含有人脸的图像或视频流 并自动在图像中检测和跟踪人脸 进而对检测到的人脸进行脸部的一系列相关技术 其中技术包括图像采集 特征定位 身份的确
  • jmeter3.3调用数据库写存储过程注意点

    1 数据库配置页面 2 创建存储过程要保证库里没有同名 本来这句drop语句放在创建存储过程里面的 发现会导致不会执行存储过程 一定要分开写 query type选择 update statement 3 创建存储过程 variable n
  • 5-6)视图与索引

    文章目录 视图 一 视图概述 1 视图的定义 2 视图的分类 3 视图的优缺 二 创建视图 使用T SQL 语句创建视图 创建视图注意的问题 三 使用视图 视图 视图是一种虚拟的数据表 Virtual table 来源于数据表和其他数据 一
  • 实验一:交换机的配置与管理-计算机网络

    交换机的配置与管理 技术原理 交换机的管理方式基本分为两种 带内管理和带外管理 通过交换机的Console端口管理交换机属于带外管理 这种管理方式不占用交换机的网络端口 第一次配置交换机必须利用Console端口进行配置 交换机的命令行操作
  • Python爬虫学习之数据提取(XPath)

    Python爬虫学习之数据提取XPath 概述 常用规则 运算符及介绍 准备工作 实例 文本获取 属性获取 属性值匹配 属性多值匹配 多属性匹配 按序选择 概述 XPath的全称是XML Path Language 即XML路径语言 用来在

随机推荐

  • url参数加密_百度逆推link?url=xxx加密算法“反推技术秒收"

    熟悉百度的站长都知道 凡是被百度搜索引擎收录的网站链接 都会生成一个以baidu开头的多参数跳转链接 而所谓 百度反推技术 的原理就是把百度生成的这个链接地址换成自己想要被收录的页面链接就可以了 然后再进行百度快照的投诉 就可以达到秒来蜘蛛
  • SQL:exec sp_executesql 用法

    這種是無效的過程 declare sql nvarchar 500 where nvarchar 500 i nvarchar 64 p nvarchar 50 id int set id 5 set sql select p AreaCo
  • 机器学习系列(8):人脸识别基本原理及Python实现

    众所周知 人脸识别和人脸验证已经得到大量应用 那么它们之间有什么异同呢 又是如何实现的呢 这里是机器学习系列第八篇 带你揭开它们神秘的面纱 若图片挂了 可移步 https mp weixin qq com s biz MzU4NTY1NDM
  • STC仿真失败

    原因就是购买的下载工具不适合在烧写STC8H3K64S仿真固件后再将该下载工具作为USB转串口工具连接PC与目标板 推测是接入仿真时会重启目标板 不打算细究 换一个普通串口就好了
  • VS编译.cu文件源文件无法打开matrix.h和mex.h问题

    配置好cu和VS相关库文件后CUDA程序仍然报错 无法打开matrix h和mex h 解决办法 1 这两个头文件是matlab中的 可能无法直接在VS中调用 可以通过添加外部依赖项的方法将matlab中的头文件的文件路径添加进来 VS中按
  • 机器学习:聚类算法API初步使用

    学习目标 知道聚类算法API的使用 1 api介绍 sklearn cluster KMeans n clusters 8 参数 n clusters 开始的聚类中心数量 整型 缺省值 8 生成的聚类数 即产生的质心 centroids 数
  • SQL批量处理+JDBC操作大数据及工具类的封装

    SQL批量处理 JDBC操作大数据及工具类的封装 一 批处理 批量处理sql语句 在jdbc的url中添加rewriteBatchedStatements true参数 可以提高批处理执行效率 在我们进行大批量数据操作的时候 需要采用批处理
  • 不使用80,443,端口,域名还需要备案吗?域名没有备案应该怎么选服务器。

    在互联网日益发达的今天 越来越多的个人 企业 公司涌入其中 在服务器 域名 大量供需的情况下身为一个小白应该要注意什么呢 首先要明确你所需要的服务器是国内大陆服务器 如 杭州 扬州 镇江 宁波等 还是海外服务器 如 香港 美国 日本 韩国等
  • uniapp写h5如何封装一个图片上传预览并且有进度条的组件

    开发背景 首先项目是用uniapp写的h5项目 要求能上传 预览 和进度条展示 还要求总览的时候用缩略图 点开预览要原图 不得不吐槽一下 开发环境 uniapp 阿里云存储 先看截图效果 好了直接上代码 photo picker vue
  • Python基础知识点梳理

    本文简要梳理了Python基础知识的大体框架 目录 一 变量和赋值 二 分支和循环 1 分支结构 2 循环结构 三 数据结构 四 函数 lambda函数 匿名函数 五 面向对象 1 封装 2 继承 六 模块和包 一 变量和赋值 变量是编程语
  • Python 基于 opencv 车牌识别系统的研究与实现

    源码下载地址 https download csdn net download gdutxiaoxu 87419195 原理简介 车牌字符识别使用的算法是opencv的SVM opencv的SVM使用代码来自于opencv附带的sample
  • 数据湖--概念、特征、架构与案例概述

    一 什么是数据湖 数据湖是目前比较热的一个概念 许多企业都在构建或者计划构建自己的数据湖 但是在计划构建数据湖之前 搞清楚什么是数据湖 明确一个数据湖项目的基本组成 进而设计数据湖的基本架构 对于数据湖的构建至关重要 关于什么是数据湖 有如
  • 上传图片到七牛云

    public JSONObject uploadImgToQiniu RequestParam MultipartFile file HttpServletResponse response HttpServletRequest reque
  • 论文笔记: Masked Autoencoders Are Scalable Vision Learners

    1 整体思路 效仿BERT中MLM的思路 随机mask掉输入图像的部分patch 并重建这些被mask掉的patch 机器学习笔记 ELMO BERT UQI LIUWJ的博客 CSDN博客 模型结构是一个非对称的encoder decod
  • 生命在于折腾——SQL注入的实操(五)less21-25

    一 实操环境 1 操作系统 VMware虚拟机创建的win10系统 内存8GB 硬盘255GB 处理器AMD Ryzen 9 5900HX 2 操作项目 sql lib项目 本篇文章介绍关卡21 25 3 工具版本 phpstudy 8 1
  • 百度翻译接口获取过程

    记百度翻译接口获取过程 coding utf 8 usr bin env python 思路 进入到百度翻译 https fanyi baidu com 首先要找到返回数据的接口 打开f12 输入你要翻译的内容后能看到很多请求如图所示 进入
  • Qt播放音乐报错DirectShowPlayerService::doSetUrlSource: Unresolved error code 0x80070002 ()

    需求 在Qt中播放背景音乐 代码片段如下 1 pro添加组件 QT multimedia 2 使用 QMediaPlayer 对象实现播放音乐 循环播放背景音乐 void ClearApp playBG QMediaPlayer playe
  • Qt自定义图片按钮并设置方向

    Qt自定义图片按钮 设置方向 虽然Qt定义了很多很多控件 但是还是不能满足用户的需要 比如如果想使用ToolButton 需要带文字 又需要文字可以设定位置 显然就不行了 下面的代码就是一个简单的实现ToolButton功能 并且能够设置图
  • 【模拟电路】三极管做开关,各个电阻的作用

    下面介绍用NPN做开关来驱动蜂鸣器的用法 对各个电阻的用法的解释 图一 这个比较简单 R20是限流作用 R21也是限流作用 图二 相同的地方就不说了 不同的地方是在基极和发射极之间加了一个电阻 这个电阻主要有两个作用 作用一 防止三极管因为
  • 神经网络中FLOPs和MACs的计算(基于thop和fvcore.nn)

    以 输入为 1 1 200 3 的张量 卷积取 nn Conv2d 1 64 kernel size 8 1 stride 2 1 padding 0 0 为例 先计算输出的形状 公式为 H上为 200 0 8 2 1 97 W上依然是3