pytorch FX模型静态量化

2023-11-18


前言

以前面文章写到的mobilenet图像分类为例,本文主要记录一下pytorchh训练后静态量化的过程。


一、pytorch静态量化(手动版)

静态量化是最常用的量化形式,float32的模型量化成int8,模型大小大概变为原来的1/4,推理速度我在intel 8700k CPU上测试速度正好快4倍,但是在AMD的5800h CPU 上测试速度反而慢了两倍,应该是AMD不支持某些指令集加速。

踩坑:

之前手动添加量化节点的方式搞了好几天,最后模型是出来了,但是推理时候报错,大多数时候是RuntimeError: Could not run ‘--------’ with arguments from the ‘CPU’ backend,网上是说推理的时候没有安插QuantStub()和DeQuantStub(),可能是用的这个MobilenetV3网络结构复杂,某些地方没有手动添加到,这种方式肯定是可以成功的,只是比较麻烦容易出错。

# 加载模型
model = MobileNetV3_Large(2).to(device)  # 加载一个网络,我这边是二分类传了一个2
checkpoint = torch.load(weights, map_location=device)
model.load_state_dict(checkpoint)
model.to('cpu').eval()

合并层对于一些重复使用的Block和nn.Sequential要打印出来看,然后append到mix_list 里面
比如

# 打印model
for name, module in model.named_children():
	print(name, module)

比如这里Sequential里面存在conv+bn+relu,append进去的应该是[‘bneck.0.conv1’, ‘bneck.0.bn1’,‘nolinear1’],但是nolinear1是个变量,也就是说某些时候是relu某些时候又不是,这种时候就要一个个分析判断好然后写代码,稍微复杂点就容易出错或者遗漏。

backend = "fbgemm"  # x86平台
model.qconfig = torch.quantization.get_default_qconfig(backend)
mix_list = [['conv1','bn1'], ['conv2','bn2']] # 合并层只支持conv+bn conv+relu conv+bn+relu等操作,具体可以查一下,网络中存在的这些操作都append到mix_list里面
model = torch.quantization.fuse_modules(model,listmix) # 合并某些层
model_fp32_prepared = torch.quantization.prepare(model)
model_int8 = torch.quantization.convert(model_fp32_prepared)

有时候存在不支持的操作relu6这些要替换成relu,加法操作也要替换,最后还要输入一批图像校准模型等

self.skip_add = nn.quantized.FloatFunctional()
# forward的时候比如return a+b 改为return self.skip_add.add(a, b)

一系列注意事项操作完毕,最后推理各种报错,放弃了

二、使用FX量化

1.版本

fx量化版本也有坑,之前在torch 1.7版本操作总是报错搞不定,换成1.12.0版本就正常了,这一点非常重要。

2.代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import torchvision
from torchvision import transforms
from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.quantization import get_default_qconfig
from torch import optim
import os
import time
from utils import load_data
from models.mobilenetv3copy import MobileNetV3_Large


def evaluate_model(model, test_loader, device, criterion=None):
    model.eval()
    model.to(device)
    running_loss = 0
    running_corrects = 0
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        if criterion is not None:
            loss = criterion(outputs, labels).item()
        else:
            loss = 0
        # statistics
        running_loss += loss * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
    eval_loss = running_loss / len(test_loader.dataset)
    eval_accuracy = running_corrects / len(test_loader.dataset)
    return eval_loss, eval_accuracy


def quant_fx(model, data_loader):
    model_to_quantize = copy.deepcopy(model)
    model_to_quantize.eval()
    qconfig = get_default_qconfig("fbgemm")
    qconfig_dict = {"": qconfig}
    prepared_model = prepare_fx(model_to_quantize, qconfig_dict)
    print("开始校准")
    calibrate(prepared_model, data_loader)  # 这是输入一批有代表性的数据来校准
    print("校准完毕")
    quantized_model = convert_fx(prepared_model)  # 转换
    return quantized_model


def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in train_loader:
            model(image)


if __name__ == "__main__":
    cuda_device = torch.device("cuda:0")
    cpu_device = torch.device("cpu:0")
    model = MobileNetV3_Large(2)  # 加载自己的网络
    train_loader, test_loader = load_data(64, 8)  # 自己写一个pytorch加载数据的方法
    
    # quantization
    state_dict = torch.load('./mymodel.pth')  # 加载一个正常训练好的模型
    model.load_state_dict(state_dict)
    model.to('cpu')
    model.eval()
    quant_model = quant_fx(model, train_loader)  # 执行量化代码
    quant_model.eval()
    print("开始验证")
    eval_loss, eval_accuracy = evaluate_model(model=quant_model,
                                              test_loader=test_loader,
                                              device=cpu_device,
                                              criterion=nn.CrossEntropyLoss())
    print("Epoch: {:02d} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(
        -1, eval_loss, eval_accuracy))
    torch.jit.save(torch.jit.script(quant_model), 'outQuant.pth')  # 保存量化后的模型
    
    # 加载量化模型推理
    loaded_quantized_model = torch.jit.load('outQuant.pth')
    eval_loss, eval_accuracy = evaluate_model(model=quant_model,
                                              test_loader=test_loader,
                                              device=cpu_device,
                                              criterion=nn.CrossEntropyLoss())
    print("Epoch: {:02d} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(
        -1, eval_loss, eval_accuracy))

fx量化也不用管里面什么算子不支持之类的,开箱即用,以上代码参考pytorch官网https://pytorch.org/docs/stable/fx.html
最后验证模型精度下降0.02%可以忽略不计,pytorch量化的模型是不支持gpu推理的,只能在arm或者x86平台实现压缩提速。要用cuda的话要上tensorrt+onnx,以后完成了再讲。完整的训练模型量化模型的代码后面会放到github上面。
完整代码:https://github.com/Ysnower/pytorch-static-quant


总结

刚开始搞量化坑比较多,一个是某些操作不支持,合并层麻烦,另外有版本问题导致的报错可能搞很久,觉得有用的各位吴彦祖麻烦送个免费三连

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch FX模型静态量化 的相关文章

  • pytorch index_put_给出运行时错误:“索引”的导数未实现

    这是后续问题这个问题 https stackoverflow com q 65584330 3337089 我尝试使用index put 如建议的答案 https stackoverflow com a 65584479 3337089 但
  • 如何修复输入和参数张量不在同一设备上?

    我看到其他人也遇到此错误 我尝试按照步骤解决 但仍然收到此错误 运行时错误 输入和参数张量不在同一设备上 在 cpu 处找到输入张量 在 cuda 0 处找到参数张量 我运行 model to device 和 input seq to d
  • Pytorch 说 CUDA 不可用(在 Ubuntu 上)

    我正在尝试在我拥有的笔记本电脑上运行 Pytorch 这是一个较旧的型号 但它确实有 Nvidia 显卡 我意识到这可能不足以实现真正的机器学习 但我正在尝试这样做 以便我可以了解安装 CUDA 的过程 我已按照上面的步骤操作安装指南 ht
  • 检查 PyTorch 张量在 epsilon 内是否相等

    如何检查两个 PyTorch 张量在语义上是否相等 考虑到浮点错误 我想知道元素是否仅相差一个小的 epsilon 值 在撰写本文时 这是最新稳定版本 0 4 1 中的一个未记录的函数 但文档位于master unstable branch
  • max_length、填充和截断参数在 HuggingFace 的 BertTokenizerFast.from_pretrained('bert-base-uncased') 中如何工作?

    我正在处理文本分类问题 我想使用 BERT 模型作为基础 然后使用密集层 我想知道这 3 个参数是如何工作的 例如 如果我有 3 个句子 My name is slim shade and I am an aspiring AI Engin
  • Win10 64位上CUDA 12的PyTorch安装

    我需要在我的 PC 上安装 PyTorch 其 CUDA 版本 12 0 pytorch 2 的表 https i stack imgur com X13oS png in In 火炬网站 https pytorch org get sta
  • 如何检查 PyTorch 是否正在使用 GPU?

    如何检查 PyTorch 是否正在使用 GPU 这nvidia smi命令可以检测 GPU 活动 但我想直接从 Python 脚本内部检查它 这些功能应该有助于 gt gt gt import torch gt gt gt torch cu
  • 通过 Conda 安装 PyTorch

    目标 使用 pytorch 和 torchvision 创建 conda 环境 Anaconda 导航器 1 8 3 python 3 6 MacOS 10 13 4 我尝试过的 在Navigator中 创建了一个新环境 尝试安装 pyto
  • RuntimeError:维度指定为 0 但张量没有维度

    我试图使用 MNIST 数据集实现简单的 NN 但我不断收到此错误 将 matplotlib pyplot 导入为 plt import torch from torchvision import models from torchvisi
  • 预训练 Transformer 模型的配置更改

    我正在尝试为重整变压器实现一个分类头 分类头工作正常 但是当我尝试更改配置参数之一 config axis pos shape 即模型的序列长度参数时 它会抛出错误 Reformer embeddings position embeddin
  • 删除 Torch 张量中的行

    我有一个火炬张量如下 a tensor 0 2215 0 5859 0 4782 0 7411 0 3078 0 3854 0 3981 0 5200 0 1363 0 4060 0 2030 0 4940 0 1640 0 6025 0
  • 如何平衡 GAN 中生成器和判别器的性能?

    这是我第一次使用 GAN 我面临着判别器多次优于生成器的问题 我正在尝试重现PA模型来自本文 http openaccess thecvf com content ICCV 2017 papers Sajjadi EnhanceNet Si
  • 使用 KL 散度时,变分自动编码器为每个输入 mnist 图像提供相同的输出图像

    当不使用 KL 散度项时 VAE 几乎完美地重建 mnist 图像 但在提供随机噪声时无法正确生成新图像 当使用 KL 散度项时 VAE 在重建和生成图像时都会给出相同的奇怪输出 这是损失函数的 pytorch 代码 def loss fu
  • 尝试理解 Pytorch 的 LSTM 实现

    我有一个包含 1000 个示例的数据集 其中每个示例都有5特征 a b c d e 我想喂7LSTM 的示例 以便它预测第 8 天的特征 a 阅读 nn LSTM 的 Pytorchs 文档 我得出以下结论 input size 5 hid
  • 从打包序列中获取每个序列的最后一项

    我试图通过 GRU 放置打包和填充的序列 并检索每个序列最后一项的输出 当然我的意思不是 1项目 但实际上是最后一个 未填充的项目 我们预先知道序列的长度 因此应该很容易为每个序列提取length 1 item 我尝试了以下方法 impor
  • pytorch 中的 autograd 可以处理同一模块中层的重复使用吗?

    我有一层layer in an nn Module并在一次中使用两次或多次forward步 这个的输出layer稍后输入到相同的layer pytorch可以吗autograd正确计算该层权重的梯度 def forward x x self
  • 如何计算 CNN 第一个线性层的维度

    目前 我正在使用 CNN 其中附加了一个完全连接的层 并且我正在使用尺寸为 32x32 的 3 通道图像 我想知道是否有一个一致的公式可以用来计算第一个线性层的输入尺寸和最后一个卷积 最大池层的输入 我希望能够计算第一个线性层的尺寸 仅给出
  • Pytorch 损失为 nan

    我正在尝试用 pytorch 编写我的第一个神经网络 不幸的是 当我想要得到损失时遇到了问题 出现以下错误信息 RuntimeError Function LogSoftmaxBackward0 returned nan values in
  • pytorch 的 IDE 自动完成

    我正在使用 Visual Studio 代码 最近尝试了风筝 这两者似乎都没有 pytorch 的自动完成功能 这些工具可以吗 如果没有 有人可以推荐一个可以的编辑器吗 谢谢你 使用Pycharmhttps www jetbrains co
  • PyTorch 中的连接张量

    我有一个张量叫做data形状的 128 4 150 150 其中 128 是批量大小 4 是通道数 最后 2 个维度是高度和宽度 我有另一个张量叫做fake形状的 128 1 150 150 我想放弃最后一个list array从第 2 维

随机推荐

  • spring gateway 的搭建与配置

    步骤 建项目 给主启动类添加Eureka的注解 EnableEurekaClient 添加并配置application yml 第一步 新建gateway的项目 gateway 8205 需要用到的组件
  • el-descriptions的使用

    el descriptions的使用 解释 我们页面有很多无序的列表展示 为了高效得去开发我们得页面 可以借助于这个组件进行适应 图片 代码 template部分
  • MIPI D-PHY介绍(二) FPGA

    MIPI D PHY介绍 二 FPGA 随着移动设备的广泛普及 MIPI D PHY作为其最主要的物理层标准之一 被越来越多地使用在各种嵌入式系统中 本文将详细介绍MIPI D PHY的工作原理和在FPGA设计中的实现方法 MIPI D P
  • 在k8s集群内搭建Prometheus监控平台

    基本架构 Prometheus由SoundCloud发布 是一套由go语言开发的开源的监控 报警 时间序列数据库的组合 Prometheus的基本原理是通过HTTP协议周期性抓取被监控组件的状态 任意组件只要提供对应的HTTP接口就可以接入
  • .NetCore技术研究-ConfigurationManager在单元测试下的坑

    最近在将原有代码迁移 NET Core 代码的迁移基本很快 当然也遇到了不少坑 重构了不少 后续逐步总结分享给大家 今天总结分享一下ConfigurationManager遇到的一个问题 先说一下场景 迁移 NET Core后 已有的配置文
  • 如何使用Visual Studio Code运行C/C++程序

    与Visual Studio 2008 2010 集成开发工具不同 Visual Studio Code只是一个代码编辑器 在Windows环境下 需下载安装 C C 编译器 配置环境等 VS Code才可以编译代码和运行程序 1 下载安装
  • javaScript基础面试题 --- 原型链

    1 原型可以解决什么问题 对象共享属性和共享方法 2 谁有原型 函数有prototype 对象有 proto 3 查找顺序 当查询一个对象的属性时 JavaScript 会首先检查对象自身是否有这个属性 如果对象本身没有该属性 那么 JS
  • 使用python和snapshot备份ElasticSearch索引数据

    该python备份snapshot的索引数据脚本 通过Elasticsearch连接es 然后通过es indices get alias函数获取所有索引名称 通过列表的startswith函数剔除 开头的自带索引名称 然后把所有索引名称放
  • 多边形的面积

    1 三角形面积 xy平面内 有三角形123 如下图所示 图1 借助矢量叉积和点积 这个三角形的面积公式非常简单 这个面积是有符号的 1 2 3逆时针排列 则面积为正 1 2 3顺时针排列 则面积为负 这是对右手系的总结 如果从背面看这个坐标
  • 11月11日 自定义Events,将自定义Events分配给UI,给UI添加动画 UE4斯坦福 学习笔记

    自定义Events 在AttributeComponent的 h头文件上加上代码 自定义Event DECLARE DYNAMIC MULTICAST DELEGATE FourParams FOnHealthChanged AActor
  • 思科模拟器简单校园网设计,期末作业难度

    文章简介 本文用思科模拟器设计和规划了一个校园网络 相当于计算机网络相关专业期末作业难度 作者简介 网络工程师 希望能认识更多的小伙伴一起交流 可私信或QQ号 1686231613 一 网络需求分析 1 学校建有办公室 实验室 教学楼 学生
  • 【STM32】RS485通信使用DMA串口发送数据出现数据丢失、断包问题排查方法

    最近在搞这个Modbus协议 由于485协议是半双工的 区别于RS 232的全双工 考虑不周导致调试modbus协议时候出了不少问题 第一 大多数开发板上的485芯片是MAX485 发送和接收状态的切换是通过IO给到这个两个引脚不同的电平进
  • win 11又更新,新功能简直绝了!

    很早之前 咱就知道微软下半年将会有一次大动作 没错 就是发布Win11 22H2正式版 之前有说过9月份发 现在也确实做到了 微软现在已经面向190多个国家 地区推送了Windows 11 22H2正式版更新 更新之后版本号为22621 5
  • linux中通过sed命令通过正则表达式过滤出中文[^[\u4E00-\u9FA5A-Za-z0-9_]+$]

    linux中通过sed命令通过正则表达式过滤出中文 sed r s u4E00 u9FA5A Za z0 9 lt gt 0 9 a z A Z g zz txt gt a txt
  • flutter listview 滚动到底部_(五) Flutter入门学习 之 Widget滚动

    列表是移动端经常使用的一种视图展示方式 在Flutter中提供了ListView和GridView 为了可能展示出更好的效果 我这里提供了一段Json数据 所以我们可以先学习一下Json解析 一 JSON读取和解析 在开发中 我们经常会使用
  • sql注入原理及解决方案

    sql注入原理就是用户输入动态的构造了意外sql语句 造成了意外结果 是攻击者有机可乘 SQL注入 SQL注入 就是通过把SQL命令插入到Web表单递交或输入域名或页面请求的查询字符串 最终达到欺骗服务器执行恶意的SQL命令 比如先前的很多
  • 随机练习题:浅浅固定思路

    1 牛牛的10类人 2 牛牛的四叶玫瑰数 3 牛牛的替换 4 牛牛的素数判断 笔者开头感想 如今大部分高校已经开学 当然笔者也不列外 但是由于疫情的原因 笔者被迫在家上网课学习 一脸忧愁 而这恰恰给了笔者自学的机会 相信笔者会加油滴 按照时
  • Acwing-4366. 上课睡觉

    假设最终答案为每堆石子均为cnt个 cnt一定可以整除sum 石子的总数 我们可以依次枚举答案 sum小于等于10 6 所以cnt的数量等于sum约数的个数 10 6范围内 约数最多的数为720720 它的约数个数有240个 int范围内
  • 单边带(SSB)调制技术

    文章目录 单边带 SSB 调制技术 1 双边带简述 2 单边带调制 单边带 SSB 调制技术 1 双边带简述 首先简述一下双边带调制 所谓双边带 DSB double sideband 调制 本质上就是调幅 时域上将基带信号x t 和高频载
  • pytorch FX模型静态量化

    文章目录 前言 一 pytorch静态量化 手动版 踩坑 二 使用FX量化 1 版本 2 代码如下 总结 前言 以前面文章写到的mobilenet图像分类为例 本文主要记录一下pytorchh训练后静态量化的过程 一 pytorch静态量化