pytorch量化库使用(1)

2023-10-27

量化简介

量化是指以低于浮点精度的位宽执行计算和存储张量的技术。量化模型以降低的精度而不是全精度(浮点)值对张量执行部分或全部运算。这允许更紧凑的模型表示以及在许多硬件平台上使用高性能矢量化操作。与典型的 FP32 模型相比,PyTorch 支持 INT8 量化,从而使模型大小减少 4 倍,内存带宽要求减少 4 倍。与 FP32 计算相比,对 INT8 计算的硬件支持通常快 2 到 4 倍。量化主要是一种加速推理的技术,量化运算符仅支持前向传递。

PyTorch 支持多种量化深度学习模型的方法。大多数情况下,模型在 FP32 中训练,然后模型转换为 INT8。此外,PyTorch 还支持量化感知训练,它使用假量化模块对前向和后向传递中的量化误差进行建模。请注意,整个计算都是以浮点形式进行的。在量化感知训练结束时,PyTorch 提供转换函数,将训练后的模型转换为较低精度。

在较低级别,PyTorch 提供了一种表示量化张量并使用它们执行操作的方法。它们可用于直接构建以较低精度执行全部或部分计算的模型。提供了更高级别的 API,其中包含将 FP32 模型转换为较低精度的典型工作流程,并且精度损失最小。

量化 API 总结

PyTorch 提供两种不同的量化模式:Eager 模式量化和 FX Graph 模式量化。

Eager 模式量化是测试版功能。用户需要进行融合并手动指定量化和反量化发生的位置,而且它仅支持模块而不支持函数。

FX 图形模式量化是 PyTorch 中的一个新的自动量化框架,目前它是一个原型功能。它通过添加对泛函的支持和自动化量化过程来改进 Eager 模式量化,尽管人们可能需要重构模型以使模型与 FX 图形模式量化兼容(通过 进行符号追踪)torch.fx。请注意,FX 图形模式量化预计不适用于任意模型,因为该模型可能无法符号追踪,我们会将其集成到 torchvision 等域库中,并且用户将能够使用 FX 量化与支持的域库中的模型类似的模型图模式量化。对于任意模型,我们将提供一般指南,但要真正使其发挥作用,用户可能需要熟悉torch.fx,特别是如何使模型具有符号可追溯性。

我们鼓励量化的新用户首先尝试 FX 图形模式量化,如果不起作用,用户可以尝试遵循使用 FX 图形模式量化的指南或回退到 eager 模式量化。

下表比较了 Eager 模式量化和 FX Graph 模式量化之间的差异:

 

支持三种类型的量化:

  1. 动态量化(通过以浮点形式读取/存储的激活进行量化的权重并进行量化以进行计算)

  2. 静态量化(权重量化、激活量化、训练后需要校准)

  3. 静态量化感知训练(权重量化、激活量化、训练期间建模的量化数值)

请参阅我们的PyTorch 量化简介博客文章,以更全面地概述这些量化类型之间的权衡。

动态和静态量化之间的运算符覆盖范围有所不同,如下表所示。请注意,对于 FX 量化,还支持相应的泛函。

 Eager Mode 量化

有关量化流程的一般介绍,包括不同类型的量化,请查看一般量化流程

训练后动态量化

这是最简单的量化应用形式,其中权重提前量化,但激活在推理过程中动态量化。这用于模型执行时间主要由从内存加载权重而不是计算矩阵乘法的情况。对于小批量的 LSTM 和 Transformer 类型模型来说确实如此。

# original model
# all tensors and computations are in floating point
previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
                 /
linear_weight_fp32

# dynamically quantized model
# linear and LSTM weights are in int8
previous_layer_fp32 -- linear_int8_w_fp32_inp -- activation_fp32 -- next_layer_fp32
                     /
   linear_weight_int8

动态量化 PTDQ API 示例:

import torch

# define a floating point model
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = self.fc(x)
        return x

# create a model instance
model_fp32 = M()
# create a quantized model instance
model_int8 = torch.ao.quantization.quantize_dynamic(
    model_fp32,  # the original model
    {torch.nn.Linear},  # a set of layers to dynamically quantize
    dtype=torch.qint8)  # the target dtype for quantized weights

# run the model
input_fp32 = torch.randn(4, 4, 4, 4)
res = model_int8(input_fp32)

训练后静态量化

训练后静态量化(PTQ static)量化模型的权重和激活。它尽可能将激活融合到前面的层中。它需要使用代表性数据集进行校准,以确定激活的最佳量化参数。当内存带宽和计算节省都很重要且 CNN 是典型用例时,通常会使用训练后静态量化。

# original model
# all tensors and computations are in floating point
previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
                    /
    linear_weight_fp32

# statically quantized model
# weights and activations are in int8
previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8
                    /
  linear_weight_int8

静态量化 PTSQ API 示例:

import torch

# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x

# create a model instance
model_fp32 = M()

# model must be set to eval mode for static quantization logic to work
model_fp32.eval()

# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'x86' for server inference and 'qnnpack'
# for mobile inference. Other quantization configurations such as selecting
# symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques
# can be specified here.
# Note: the old 'fbgemm' is still available but 'x86' is the recommended default
# for server inference.
# model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')

# Fuse the activations to preceding layers, where applicable.
# This needs to be done manually depending on the model architecture.
# Common fusions include `conv + relu` and `conv + batchnorm + relu`
model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv', 'relu']])

# Prepare the model for static quantization. This inserts observers in
# the model that will observe activation tensors during calibration.
model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)

# calibrate the prepared model to determine quantization parameters for activations
# in a real world setting, the calibration would be done with a representative dataset
input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)

# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
# used with each activation tensor, and replaces key operators with quantized
# implementations.
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)

# run the model, relevant calculations will happen in int8
res = model_int8(input_fp32)

静态量化的量化感知训练

量化感知训练 (QAT) 对训练过程中的量化效果进行建模,与其他量化方法相比,具有更高的准确性。我们可以对静态、动态或仅权值量化进行 QAT。在训练期间,所有计算均以浮点形式完成,fake_quant 模块通过钳位和舍入对量化效果进行建模,以模拟 INT8 的效果。模型转换后,权重和激活被量化,并且激活尽可能融合到前一层中。它通常与 CNN 一起使用,并且与静态量化相比具有更高的精度。

# original model
# all tensors and computations are in floating point
previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
                      /
    linear_weight_fp32

# model with fake_quants for modeling quantization numerics during training
previous_layer_fp32 -- fq -- linear_fp32 -- activation_fp32 -- fq -- next_layer_fp32
                           /
   linear_weight_fp32 -- fq

# quantized model
# weights and activations are in int8
previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8
                     /
   linear_weight_int8

QAT API 示例:

import torch

# define a floating point model where some layers could benefit from QAT
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.bn = torch.nn.BatchNorm2d(1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x

# create a model instance
model_fp32 = M()

# model must be set to eval for fusion to work
model_fp32.eval()

# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'x86' for server inference and 'qnnpack'
# for mobile inference. Other quantization configurations such as selecting
# symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques
# can be specified here.
# Note: the old 'fbgemm' is still available but 'x86' is the recommended default
# for server inference.
# model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')

# fuse the activations to preceding layers, where applicable
# this needs to be done manually depending on the model architecture
model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32,
    [['conv', 'bn', 'relu']])

# Prepare the model for QAT. This inserts observers and fake_quants in
# the model needs to be set to train for QAT logic to work
# the model that will observe weight and activation tensors during calibration.
model_fp32_prepared = torch.ao.quantization.prepare_qat(model_fp32_fused.train())

# run the training loop (not shown)
training_loop(model_fp32_prepared)

# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
# used with each activation tensor, fuses modules where appropriate,
# and replaces key operators with quantized implementations.
model_fp32_prepared.eval()
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)

# run the model, relevant calculations will happen in int8
res = model_int8(input_fp32)

静态量化的模型准备

目前有必要在 Eager 模式量化之前对模型定义进行一些修改。这是因为当前量化是逐个模块进行的。具体来说,对于所有量化技术,用户需要:

  1. 将任何需要输出重新量化(因此具有附加参数)的操作从泛函转换为模块形式(例如,使用torch.nn.ReLU代替torch.nn.functional.relu)。

  2. .qconfig通过在子模块上分配属性或指定 来指定模型的哪些部分需要量化 qconfig_mapping。例如,设置意味着该 图层不会被量化,设置 意味着将使用的量化设置而不是全局 qconfig。model.conv1.qconfig = Nonemodel.convmodel.linear1.qconfig = custom_qconfigmodel.linear1custom_qconfig

对于量化激活的静态量化技术,用户还需要执行以下操作:

  1. 指定激活的量化和反量化位置。这是使用 QuantStub和 DeQuantStub模块完成的。

  2. 用于FloatFunctional将需要特殊处理量化的张量运算包装到模块中。例如,诸如addcat之类的操作需要特殊处理来确定输出量化参数。

  3. 熔断模块:将操作/模块组合成单个模块以获得更高的精度和性能。这是使用 fuse_modules()API 完成的,它接收要融合的模块列表。我们目前支持以下融合:[Conv, Relu]、[Conv, BatchNorm]、[Conv, BatchNorm, Relu]、[Linear, Relu]

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch量化库使用(1) 的相关文章

  • 如果两点之间的距离低于某个阈值,则从列表中删除点

    我有一个点列表 只有当它们之间的距离大于某个阈值时 我才想保留列表中的点 因此 从第一个点开始 如果第一个点和第二个点之间的距离小于阈值 那么我将删除第二个点 然后计算第一个点和第三个点之间的距离 如果该距离小于阈值 则比较第一点和第四点
  • 与区域指示符字符类匹配的 python 正则表达式

    我在 Mac 上使用 python 2 7 10 表情符号中的标志由一对表示区域指示符号 https en wikipedia org wiki Regional Indicator Symbol 我想编写一个 python 正则表达式来在
  • 使用特定的类/函数预加载 Jupyter Notebook

    我想预加载一个笔记本 其中包含我在另一个文件中定义的特定类 函数 更具体地说 我想用 python 来做到这一点 比如加载一个配置文件 包含所有相关的类 函数 目前 我正在使用 python 生成笔记本并在服务器上自动启动它们 因为不同的
  • 如何用python脚本控制TP LINK路由器

    我想知道是否有一个工具可以让我连接到路由器并关闭它 然后从 python 脚本重新启动它 我知道如果我写 import os os system ssh l root 192 168 2 1 我可以通过 python 连接到我的路由器 但是
  • Python 中的哈希映射

    我想用Python实现HashMap 我想请求用户输入 根据他的输入 我从 HashMap 中检索一些信息 如果用户输入HashMap的某个键 我想检索相应的值 如何在 Python 中实现此功能 HashMap
  • Pandas 日期时间格式

    是否可以用零后缀表示 pd to datetime 似乎零被删除了 print pd to datetime 2000 07 26 14 21 00 00000 format Y m d H M S f 结果是 2000 07 26 14
  • YOLOv8获取预测边界框

    我想将 OpenCV 与 YOLOv8 集成ultralytics 所以我想从模型预测中获取边界框坐标 我该怎么做呢 from ultralytics import YOLO import cv2 model YOLO yolov8n pt
  • 在Python中连接反斜杠

    我是 python 新手 所以如果这听起来很简单 请原谅我 我想加入一些变量来生成一条路径 像这样 AAAABBBBCCCC 2 2014 04 2014 04 01 csv Id TypeOfMachine year month year
  • 使用 xlrd 打开 BytesIO (xlsx)

    我正在使用 Django 需要读取上传的 xlsx 文件的工作表和单元格 使用 xlrd 应该可以 但因为文件必须保留在内存中并且可能不会保存到我不知道如何继续的位置 本例中的起点是一个带有上传输入和提交按钮的网页 提交后 文件被捕获req
  • Python beautifulsoup 仅限 1 级文本

    我看过其他 beautifulsoup 得到相同级别类型的问题 看来我的有点不同 这是网站 我正试图拿到右边那张桌子 请注意表的第一行如何展开为该数据的详细细分 我不想要那个数据 我只想要最顶层的数据 您还可以看到其他行也可以展开 但在本例
  • 在 Sphinx 文档中*仅*显示文档字符串?

    Sphinx有一个功能叫做automethod从方法的文档字符串中提取文档并将其嵌入到文档中 但它不仅嵌入了文档字符串 还嵌入了方法签名 名称 参数 我如何嵌入only文档字符串 不包括方法签名 ref http www sphinx do
  • 如何通过 TLS 1.2 运行 django runserver

    我正在本地 Mac OS X 机器上测试 Stripe 订单 我正在实现这段代码 stripe api key settings STRIPE SECRET order stripe Order create currency usd em
  • pip 列出活动 virtualenv 中的全局包

    将 pip 从 1 4 x 升级到 1 5 后pip freeze输出我的全局安装 系统 软件包的列表 而不是我的 virtualenv 中安装的软件包的列表 我尝试再次降级到 1 4 但这并不能解决我的问题 这有点类似于这个问题 http
  • 如何断言 Unittest 上的可迭代对象不为空?

    向服务提交查询后 我会收到一本字典或一个列表 我想确保它不为空 我使用Python 2 7 我很惊讶没有任何assertEmpty方法为unittest TestCase类实例 现有的替代方案看起来并不正确 self assertTrue
  • Pandas 将多行列数据帧转换为单行多列数据帧

    我的数据框如下 code df Car measurements Before After amb temp 30 268212 26 627491 engine temp 41 812730 39 254255 engine eff 15
  • 如何解决 PDFBox 没有 unicode 映射错误?

    我有一个现有的 PDF 文件 我想使用 python 脚本将其转换为 Excel 文件 目前正在使用PDFBox 但是存在多个类似以下错误 org apache pdfbox pdmodel font PDType0Font toUnico
  • Scipy Sparse:SciPy/NumPy 更新后出现奇异矩阵警告

    我的问题是由大型电阻器系统的节点分析产生的 我基本上是在设置一个大的稀疏矩阵A 我的解向量b 我正在尝试求解线性方程A x b 为了做到这一点 我正在使用scipy sparse linalg spsolve method 直到最近 一切都
  • Pandas 每周计算重复值

    我有一个Dataframe包含按周分组的日期和 ID df date id 2022 02 07 1 3 5 4 2022 02 14 2 1 3 2022 02 21 9 10 1 2022 05 16 我想计算每周有多少 id 与上周重
  • 更改 Tk 标签小部件中单个单词的颜色

    我想更改 Tkinter 标签小部件中单个单词的字体颜色 我知道可以使用文本小部件来实现与我想要完成的类似的事情 例如使单词 YELLOW 显示为黄色 self text tag config tag yel fg clr yellow s
  • Kivy - 单击按钮时编辑标签

    我希望 Button1 在单击时编辑标签 etykietka 但我不知道如何操作 你有什么想法吗 class Zastepstwa App def build self lista WebOps getList layout BoxLayo

随机推荐

  • matlab函数之reshape()

    reshape 重构数组 功能 B reshape A sz 按矢量sz定义的维度 包括行数 列数 维数 重构矩阵A来得到矩阵B 实现原理 先将矩阵A先排列成一列 结果感受就是按列优先排列 再按照矢量sz定义大小的行数切割 结构及实例 A
  • 区间图着色问题

    这是算法导论贪心算法一章的一个习题 题目描述 假定有一组活动 我们需要将它们安排到一些教室 任意活动都可以在任意教室进行 我们希望使用最少的教室完成所有的活动 设计一个高效的贪心算法求每个活动应该在哪个教室进行 这个问题称为区间图着色问题
  • 在Linux应用程序中打印函数调用栈

    在Linux中打印函数调用栈 要求 在Linux系统的应用程序中写一个函数print stackframe 用于获取当前位置的函数调用栈信息 方法 execinfo h库下的函数backtrace可以得到当前线程的函数调用栈指针和调用栈深度
  • ODOO15固定资产管理系统解决方案(原创)

    有些公司固定资产众多 而且涉及到在建工程的费用归集及在建工程结转固定资产等复杂情况 使用ODOO系统如何来解决这个客户需要解决的问题呢 我们根据自身的实施经验 分享ODOO固定资产的管理解决方案 1 资产分类设置 资产众多 需要进行类别设置
  • 谷歌云GCP

    感谢公司赞助了Google Cloud Platform GCP Coursera课程 https www coursera org 包括云基础设施 应用开发 数据湖和数据仓库相关知识 其中谷歌云的实验操作平台是 https www qwi
  • 数据库系统丨关系代数运算总结

    文章目录 1 需要记忆的符号 2 集合运算 1 并运算 2 差运算 3 交运算 4 广义笛卡尔积 3 关系运算 1 选择 Selection 2 投影 Projection 3 连接 Join 4 除 Division 1 需要记忆的符号
  • VOT 数据集 groundtruth 8个维度 转成 4个维度的方法

    VOT数据集由于加入了带旋转角度的boundingbox 使得其groundtruth的维度达到了8个 如下 8个维度就代表boundingbox的4个点 比如VOT16中 bag数据序列的groundtruth第一行 334 02 128
  • Casual inference 综述框架

    A survey on causal inference 因果推理综述 A Survey on Causal Inference 一文的总结和梳理 因果推断理论解读 Rubin因果模型的三个假设 基础理论 理论框架 名词解释 individ
  • 如何使用IDEA正确的创建一个Web项目

    我是学习java的新人休元 第一次使用CSDN写博客请大家多多关照 写的第一篇博客就是如何使用IDEA正确的创建一个Web项目 刚刚使用IDEA不到两个星期 有很多地方不熟练 如果有错误请大家指出来 操作系统 win10 编译环境 IDEA
  • 2023年第1季社区Task挑战赛开启,等你来战!

    社区Task挑战赛是面向社区开发者开展的代码或教程征集活动 该挑战赛为社区中热爱FISCO BCOS及周边组件的开发者提供了探索区块链技术 挑战技术难题的舞台 该挑战赛去年在社区成功举办了3季 共吸引了数百名开发者报名 前3季都有哪些有趣的
  • 从源码看 AlertDialog.getButton(DialogInterface.BUTTON_POSITIVE) 为什么是 null

    我们在使用 AlertDialog 的时候 如果想改变 POSITIVE BUTTON 或者 NEGATIVE BUTTON 的字体颜色 大小时 可能会注意到 AlertDialog getButton DialogInterface BU
  • 长城网络靶场第三题

    关卡描述 1 oa服务器的内网ip是多少 先进行ip统计 开始逐渐查看前面几个ip 基本上都是b s 所以大概率是http 过滤一下ip 第一个ip好像和oa没啥关系 第二个ip一点开就是 oa 应该就是他了 关卡描述 2 黑客的攻击ip是
  • w7系统如何关闭高级文字服务器,Win7系统怎么取消切换大小写时出现提示?

    Win7系统用户在工作中使用键盘切换大小写输入时 总会弹跳出系统的提示窗口 很多用户觉得非常烦 那么Win7系统应该怎么取消切换大小写时出现的提示呢 接下来下面请看Win7系统切换大小写时出现的提示的具体解决方法 解决方法 1 首先 在桌面
  • 在Vue中获取DOM元素的实际宽高

    最近使用 D3 js 开发可视化图表 因为移动端做了 rem 适配 所以需要动态计算获取图表容器的宽高 其中涉及到一些原生DOM API的使用 避免遗忘这里总结一下 一 获取元素 在 Vue 中可以使用 ref 来获取一个真实的 DOM 元
  • 电商峰值系统架构设计--转载

    1 1 系统架构设计目录 摘要 双11来临之际 程序员 以 电商峰值系统架构设计 为主题 力邀京东 当当 小米 1号店 海尔商城 唯品会 蘑菇街 麦包包等电商企业 及商派 基调网络等服务公司 分享电商峰值系统架构设计的最佳技术实践 自200
  • SSH_Unable to negotiate with ... port ..: nomatching host host key type found. Their offer:ssh-rsa

    终端远程登录ssh时 提示如下错误 Unable to negotiate with 192 168 1 228 port 22 nomatching host host key type found Their offer ssh rsa
  • 常见的损失函数(loss function)总结

    点击上方 小白学视觉 选择加 星标 或 置顶 重磅干货 第一时间送达 导读 本文总结了常见的八种损失函数的优缺点 包括 0 1损失函数 绝对值损失函数 log对数损失函数 平方损失函数 指数损失函数 Hinge 损失函数 感知损失函数 交叉
  • 解答:pytorch 通过索引赋值后 梯度还能正常反向传播吗

    先上测试代码 if name main scene graph token tensor1 torch rand 4 4 tensor1 requires grad True tensor2 torch rand 4 tensor2 req
  • labelme标注不同物体显示不同颜色以及批量转换

    最近在使用labelme标注数据时遇到一些问题 如上图中 蓝色分别为crack person dog三类 正常应该是3种不同颜色 解决方案 1 labelme版本 2 下载labelme后进行文件修改 由于博主想要的是rgb三通道的彩色图
  • pytorch量化库使用(1)

    量化简介 量化是指以低于浮点精度的位宽执行计算和存储张量的技术 量化模型以降低的精度而不是全精度 浮点 值对张量执行部分或全部运算 这允许更紧凑的模型表示以及在许多硬件平台上使用高性能矢量化操作 与典型的 FP32 模型相比 PyTorch