pytorch量化库使用(2)

2023-10-27

FX Graph Mode量化模式

训练后量化有多种量化类型(仅权重、动态和静态),配置通过qconfig_mapping ( prepare_fx函数的参数)完成。

FXPTQ API 示例:

import torch
from torch.ao.quantization import (
  get_default_qconfig_mapping,
  get_default_qat_qconfig_mapping,
  QConfigMapping,
)
import torch.ao.quantization.quantize_fx as quantize_fx
import copy

model_fp = UserModel()

#
# post training dynamic/weight_only quantization
#

# we need to deepcopy if we still want to keep model_fp unchanged after quantization since quantization apis change the input model
model_to_quantize = copy.deepcopy(model_fp)
model_to_quantize.eval()
qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_dynamic_qconfig)
# a tuple of one or more example inputs are needed to trace the model
example_inputs = (input_fp32)
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
# no calibration needed when we only have dynamic/weight_only quantization
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

#
# post training static quantization
#

model_to_quantize = copy.deepcopy(model_fp)
qconfig_mapping = get_default_qconfig_mapping("qnnpack")
model_to_quantize.eval()
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
# calibrate (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

#
# quantization aware training for static quantization
#

model_to_quantize = copy.deepcopy(model_fp)
qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack")
model_to_quantize.train()
# prepare
model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs)
# training loop (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

#
# fusion
#
model_to_quantize = copy.deepcopy(model_fp)
model_fused = quantize_fx.fuse_fx(model_to_quantize)

量化堆栈

量化是将浮点模型转换为量化模型的过程。因此,在高层次上,量化堆栈可以分为两部分:1)。量化模型的构建块或抽象 2)。将浮点模型转换为量化模型的量化流程的构建块或抽象

量化模型

量化张量

为了在 PyTorch 中进行量化,我们需要能够用张量表示量化数据。量化张量允许存储量化数据(表示为 int8/uint8/int32)以及量化参数(如比例和 Zero_point)。除了允许以量化格式序列化数据之外,量化张量还允许许多有用的操作,使量化算术变得容易。

PyTorch 支持每张量和每通道的对称和非对称量化。每个张量意味着张量内的所有值都使用相同的量化参数以相同的方式量化。每个通道意味着对于每个维度(通常是张量的通道维度),张量中的值使用不同的量化参数进行量化。这可以减少将张量转换为量化值时的错误,因为异常值只会影响其所在的通道,而不是整个张量。

映射是通过使用转换浮点张量来执行的

 

 

 

请注意,我们确保浮点中的零在量化后表示没有错误,从而确保诸如填充之类的操作不会导致额外的量化误差。

以下是量化张量的几个关键属性:

  • QScheme (torch.qscheme):一个枚举,指定我们量化张量的方式

    • torch.per_tensor_affine

    • torch.per_tensor_对称

    • torch.per_channel_affine

    • torch.per_channel_symmetry

  • dtype (torch.dtype):量化张量的数据类型

    • 火炬.quint8

    • 火炬.qint8

    • 火炬.qint32

    • 火炬.float16

  • 量化参数(根据 QScheme 的不同而变化):所选量化方式的参数

    • torch.per_tensor_affine 的量化参数为

      • 刻度(浮动)

      • 零点(整数)

    • torch.per_channel_affine 的量化参数为

      • per_channel_scales(浮点数列表)

      • per_channel_zero_points(整数列表)

      • 轴(整数)

量化和反量化

模型的输入和输出都是浮点张量,但量化模型中的激活是量化的,因此我们需要运算符在浮点和量化张量之间进行转换。

  • 量化(浮点 -> 量化)

    • torch.quantize_per_tensor(x, 尺度, 零点, dtype)

    • torch.quantize_per_channel(x, 尺度, Zero_points, 轴, dtype)

    • torch.quantize_per_tensor_dynamic(x,dtype,reduce_range)

    • 到(火炬.float16)

  • 反量化(量化 -> 浮点)

    • quantized_tensor.dequantize() - 在 torch.float16 张量上调用 dequantize 会将张量转换回 torch.float

    • 火炬.反量化(x)

量化运算符/模块

  • 量化算子是以量化Tensor为输入,输出量化Tensor的算子。

  • 量化模块是执行量化操作的 PyTorch 模块。它们通常是为线性和卷积等加权运算定义的。

量化引擎

当执行量化模型时,qengine (torch.backends.quantized.engine) 指定使用哪个后端来执行。重要的是要确保qengine在量化激活和权重的取值范围方面与量化模型兼容。

量化流程

观察者和 FakeQuantize

  • 观察者是 PyTorch 模块,用于:

    • 收集张量统计信息,例如通过观察者的张量的最小值和最大值

    • 并根据收集的张量统计数据计算量化参数

  • FakeQuantize 是 PyTorch 模块,用于:

    • 模拟网络中张量的量化(执行量化/反量化)

    • 它可以根据观察者收集的统计数据计算量化参数,也可以学习量化参数

查询配置

  • QConfig 是 Observer 或 FakeQuantize Module 类的命名元组,可以使用 qscheme、dtype 等进行配置。它用于配置应如何观察操作员

    • 算子/模块的量化配置

      • 不同类型的 Observer/FakeQuantize

      • 数据类型

      • q方案

      • quant_min/quant_max:可用于模拟较低精度的张量

    • 目前支持激活和权重的配置

    • 我们根据为给定运算符或模块配置的 qconfig 插入输入/权重/输出观察器

一般量化流程

一般来说,流程如下

  • 准备

    • 根据用户指定的 qconfig 插入 Observer/FakeQuantize 模块

  • 校准/训练(取决于训练后量化或量化感知训练)

    • 允许观察者收集统计数据或 FakeQuantize 模块来学习量化参数

  • 转变

    • 将校准/训练模型转换为量化模型

量化有不同的模式,它们可以分为两种方式:

就我们应用量化流程的位置而言,我们有:

  1. Post Training Quantization(训练后应用量化,量化参数根据样本校准数据计算)

  2. 量化感知训练(在训练过程中模拟量化,以便使用训练数据与模型一起学习量化参数)

就我们如何量化运算符而言,我们可以:

  • 仅权重量化(仅权重静态量化)

  • 动态量化(权重静态量化,激活动态量化)

  • 静态量化(权重和激活都是静态量化的)

我们可以在同一量化流程中混合不同的量化运算符方式。例如,我们可以进行具有静态和动态量化运算符的训练后量化。

量化支持矩阵

 

量化定制

虽然提供了观察者根据观察到的张量数据选择比例因子和偏差的默认实现,但开发人员可以提供自己的量化函数。量化可以选择性地应用于模型的不同部分,或者针对模型的不同部分进行不同的配置。

我们还为conv1d()conv2d()、 conv3d()Linear()的每通道量化提供支持。

量化工作流程通过在模型的模块层次结构中添加(例如,将观察者添加为 .observer子模块)或替换(例如,转换nn.Conv2d为 nn.quantized.Conv2d)子模块来工作。这意味着该模型nn.Module在整个过程中保持基于常规的实例,因此可以与 PyTorch API 的其余部分一起使用。

量化自定义模块 API

Eager 模式和 FX 图形模式量化 API 都为用户提供了一个钩子,以指定以自定义方式量化的模块,并使用用户定义的逻辑进行观察和量化。用户需要指定:

  1. 源 fp32 模块的 Python 类型(模型中存在)

  2. 被观察模块的Python类型(由用户提供)。该模块需要定义一个from_float函数,该函数定义如何从原始 fp32 模块创建观察到的模块。

  3. 量化模块的Python类型(由用户提供)。该模块需要定义一个from_observed函数,该函数定义如何从观察到的模块创建量化模块。

  4. 描述上述 (1)、(2)、(3) 的配置,传递给量化 API。

然后框架将执行以下操作:

  1. 在准备模块交换期间,它将使用 (2) 中类的from_float函数将 (1) 中指定类型的每个模块转换为 (2) 中指定的类型。

  2. 在转换模块交换期间,它将使用 (3) 中类的from_observed函数将 (2) 中指定类型的每个模块转换为(3) 中指定的类型。

目前,要求ObservedCustomModule将具有单个 Tensor 输出,并且框架(而不是用户)将在该输出上添加观察者。观察者将作为自定义模块实例的属性存储在activation_post_process键下。未来可能会放宽这些限制。

自定义 API 示例:

import torch
import torch.ao.nn.quantized as nnq
from torch.ao.quantization import QConfigMapping
import torch.ao.quantization.quantize_fx

# original fp32 module to replace
class CustomModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.linear(x)

# custom observed module, provided by user
class ObservedCustomModule(torch.nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return self.linear(x)

    @classmethod
    def from_float(cls, float_module):
        assert hasattr(float_module, 'qconfig')
        observed = cls(float_module.linear)
        observed.qconfig = float_module.qconfig
        return observed

# custom quantized module, provided by user
class StaticQuantCustomModule(torch.nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return self.linear(x)

    @classmethod
    def from_observed(cls, observed_module):
        assert hasattr(observed_module, 'qconfig')
        assert hasattr(observed_module, 'activation_post_process')
        observed_module.linear.activation_post_process = \
            observed_module.activation_post_process
        quantized = cls(nnq.Linear.from_float(observed_module.linear))
        return quantized

#
# example API call (Eager mode quantization)
#

m = torch.nn.Sequential(CustomModule()).eval()
prepare_custom_config_dict = {
    "float_to_observed_custom_module_class": {
        CustomModule: ObservedCustomModule
    }
}
convert_custom_config_dict = {
    "observed_to_quantized_custom_module_class": {
        ObservedCustomModule: StaticQuantCustomModule
    }
}
m.qconfig = torch.ao.quantization.default_qconfig
mp = torch.ao.quantization.prepare(
    m, prepare_custom_config_dict=prepare_custom_config_dict)
# calibration (not shown)
mq = torch.ao.quantization.convert(
    mp, convert_custom_config_dict=convert_custom_config_dict)
#
# example API call (FX graph mode quantization)
#
m = torch.nn.Sequential(CustomModule()).eval()
qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_qconfig)
prepare_custom_config_dict = {
    "float_to_observed_custom_module_class": {
        "static": {
            CustomModule: ObservedCustomModule,
        }
    }
}
convert_custom_config_dict = {
    "observed_to_quantized_custom_module_class": {
        "static": {
            ObservedCustomModule: StaticQuantCustomModule,
        }
    }
}
mp = torch.ao.quantization.quantize_fx.prepare_fx(
    m, qconfig_mapping, torch.randn(3,3), prepare_custom_config=prepare_custom_config_dict)
# calibration (not shown)
mq = torch.ao.quantization.quantize_fx.convert_fx(
    mp, convert_custom_config=convert_custom_config_dict)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch量化库使用(2) 的相关文章

  • 如何获取 Pandas df.merge() 不匹配的列名称

    给出以下数据 data df pd DataFrame Reference A A A B C C D E Value1 U U U V W W X Y Value2 u u u v w w x y index 1 2 3 4 5 6 7
  • 创建一个支持 json 序列化的类以与 Celery 一起使用

    我正在使用 Celery 来运行一些后台任务 其中一项任务返回我创建的 python 类 考虑到有关使用 pickle 的警告 我想使用 json 来序列化和反序列化此类 有没有一种简单的内置方法可以实现这一目标 该类非常简单 它包含 3
  • Django 1.6:清除一张表中的数据

    我有一个名为 UGC 的表 想要清除该表中的所有数据 我不想重置整个应用程序 这也会删除所有其他模型中的所有数据 是否可以只清除一个模型 我还为我的应用程序配置了 South 如果这有帮助的话 你可以使用原始 SQL https docs
  • Python Flask 删除请求

    我正在开发一个 Python 应用程序并使用 Flask 这是我的 DELETE 函数 app route DeleteMessage methods DELETE def DeleteMessage messages Message qu
  • 如何在redis中使用python删除排序集中的项目

    如何使用 python 删除排序集中大于某个值的项目 key foo pipe redis master conn pipeline pipe zadd key 1 a pipe zadd key 2 b pipe zadd key 3 c
  • 在 PyCharm 中启用终端模拟

    很多人告诉过我和PyCharm 2 7 的 PyCharm 发行说明 https www jetbrains com pycharm whatsnew whatsnew 27 html吹捧那个PyCharm包括完整的终端仿真 我认为这是关于
  • 如何在嵌套列表中查找给定元素?

    这是我的迭代解决方案 def exists key arg if not arg return False else for element in arg if isinstance element list for i in elemen
  • 在 Python 中解压存档时出现错误

    我使用 Python 下载 bz2 文件 然后我想使用以下方法解压存档 def unpack file dir file cwd os getcwd os chdir dir print Unpacking file s file cmd
  • 为什么我的字符串中出现不需要的换行符?

    这应该很简单 这很愚蠢 但我无法让它发挥作用 我有一个在读取文件时定义的标头 if gene env in line or gene HIV2gp7 in line header line 现在这个标题看起来像 gt lcl NC 0018
  • Python Jinja2 调用宏会导致(不需要的)换行符

    我的 JINJA2 模板如下所示 macro print if john name if name John Hi John endif endmacro Hello World print if john Foo print if joh
  • 将 postgres 连接到 django 时遇到问题

    以下文档来自Django Postgres 文档 https docs djangoproject com en 4 1 ref databases postgresql notes我添加到我的settings py 在我设置的设置中 DA
  • 如何向 Jupyter (ipython) 笔记本自动添加扩展?

    我已经安装了扩展 calico document tools 我可以使用以下命令从 Jupyter 笔记本中加载它 javascript IPython load extensions calico document tools 如何为每个
  • 过滤给定范围内的坐标

    我有数百个带有地理位置的 out 文件 我将把它们批量导入到 SQLite 数据库中 但是 为了节省时间 我只会导入地理坐标在某些间隔内的线 文件是这样的 value value longitude latitude value value
  • 如何忽略 Sentry 捕获中的某些 Python 错误

    我已将 Sentry 配置为捕获 Django Celery 应用程序中的所有错误 它工作正常 但我发现一个令人讨厌的用例是当我必须重新启动我的 Celery 工作人员 PostgreSQL 数据库或消息服务器时 这会导致数千种各种 无法访
  • pip 升级到 pip 10.x.x 后解析需求文件的正确方法?

    所以今天我确实发现随着发布pip 10 x x the req软件包更改了其目录 现在可以在下面找到pip internal req 由于通常的做法是使用parse requirements功能在你的setup py从需求文件中安装所有依赖
  • Python httplib 和 POST

    我目前正在使用别人编写的一段代码 它用httplib向服务器发出请求 它以正确的格式提供所有数据 例如消息正文 标头值等 问题是 每次尝试发送 POST 请求时 数据都在那里 我可以在客户端看到它 但没有任何内容到达服务器 我已经阅读了库规
  • Spacy-nightly (spacy 2.0) 问题“thinc.extra.MaxViolation 大小错误”

    显然成功安装了 spacy nightly spacy nightly 2 0 0a14 和英语模型 en core web sm 后 我在尝试运行它时仍然收到错误消息 import spacy nlp spacy load en core
  • print() 函数的有趣/奇怪的机制

    我正在学习Python 我目前正在学习如何定义自己的函数 并且在尝试理解返回值和打印它之间的区别时遇到了一些困难 我读到的关于这个主题的描述对我来说不太清楚 所以我开始自己尝试 我想我现在已经明白了 如果我没记错的话 区别在于你可以传递 a
  • 合并共享属性的节点

    EDITED 我真的需要 Networkx graph 专家的帮助 假设我有以下数据框 我想将这些数据框转换为图表 然后我想根据描述和优先级属性将两个图映射到相应的节点 df1 From description To priority 10
  • 获取 Flask 中没有端口的请求主机名

    我刚刚设法使用 Flask 获取我的应用程序服务器主机名request host and request url root 但这两个字段都返回请求主机名及其端口 我想使用仅返回请求主机名的字段 方法 而无需进行字符串替换 如果有 没有 We

随机推荐

  • 电机控制基础——定时器基础知识与PWM输出原理

    单片机开发中 电机的控制与定时器有着密不可分的关系 无论是直流电机 步进电机还是舵机 都会用到定时器 比如最常用的有刷直流电机 会使用定时器产生PWM波来调节转速 通过定时器的正交编码器接口来测量转速等 本篇先介绍定时器的基础知识 然后对照
  • importing maven projects 9% 卡住

    导入一个maven工程后 一直显示 importing maven projects 9 解决办法 找到eclipse安装目录下的eclipse ini 在最后加入 vm JAVA HOME bin javaw exe 再次重启eclips
  • flutter 环形进度条组件CircularProgressIndicator、线性进度条组件LinearProgressIndicator

    环形进度条组件 不能放在ListView中 若不设置value 即无value参数 会一直加载动画 LinearProgressIndicator valueColor AlwaysStoppedAnimation Colors x 设置进
  • 如何进入mysql命令界面

    1 找到安装mysql安装路径 复制bin目录地址 eg D installmysqlin 2 进入cmd命令窗口 3 因为安装到D盘 进入D盘的盘符 输入D 直接输入cdD installmysqlin是无效的 4 进入bin目录 cdD
  • video.js 播放 rtsp、hls

    什么是HLS RTSP RTMP HLS HTTP Live Streaming 苹果公司提出的流媒体协议 直接把流媒体切片成一段段 信息保存到m3u列表文件中 可以将不同速率的版本切成相应的片 播放器可以直接使用http协议请求流数据 可
  • arduino+esp8266+onenet+mqtt_4G模块(EC20)连接MQTT服务器(EMQ X)

    上面的示意图也是这篇推文的主题 使用三个客户端EC20 ESP8266 MQTTX通过MQTT协议连接上我们搭建的EMQ X服务器最后完成消息的发布和订阅 概述 其中ESP826大家都有认识 那么相对陌生的EC20是一个4G模块 MQTTX
  • vika+obsidian快速进入一个研究领域

    目的是快速熟悉一个陌生的研究领域 写出文献综述 步骤 检索相关文献100篇以上 在vika中建表格 表头如下所示 阅读100篇论文的题目 关键词 摘要 填充vika表格 并找出需要精度的文献10篇左右 中文综述优先 精度10篇论文 并用ob
  • 《程序员面试宝典》第6章 宏和const

    一 用一个宏定义FIND求一个结构体struc里某个变量相对struc的偏移量 int a char b 20 double cc FIND student a 0 FIND student b 4 FIND student cc 4 20
  • vscode 如何自动补全react和jsx代码?

    补全react代码 文件 首选项 设置 打开之后 搜索emmet includeLanguages 添加javascript javascriptreact属性 重启即可生效 补全jsx代码 在原来的搜索位置 搜索emmet trigger
  • 《游戏测试精通》观后感

    第I部分 游戏测试简介 第一章 游戏测试的两条原则 1 不要恐慌 2 不要相信任何人 在第一章第一节的学习中 我了解到了 作为一名测试工作者 不论是不是新手还是一个资深的 老人 在处于以下场景时 都会或多或少的出现恐慌的情况 不熟悉环境或业
  • PCI-E

    PCI E 1 简介 PCI E PCI Express的所写 是最新的总线和接口标准 它原来的名称为 3GIO 是由英特尔提出的 很明显英特尔的意思是它代表着下一代I O接口标准 交由PCI SIG PCI特殊兴趣组织 认证发布后才改名为
  • 1. 两数之和 C++

    给定一个整数数组 nums 和一个整数目标值 target 请你在该数组中找出 和为目标值 target 的那 两个 整数 并返回它们的数组下标 你可以假设每种输入只会对应一个答案 但是 数组中同一个元素在答案里不能重复出现 你可以按任意顺
  • 一文搞懂OC门、OD门及其作用

    我们先给出OC门 OD门的定义 然后从原理出发 介绍OC门 OD门的作用 1 什么是OC门 OD门 OC门 Open Collector Gate 集电极开路门 如图1所示 当N1导通时 输出低电平 当N1截止时 输出高阻态 电路的一种输出
  • 火狐安装网页视频下载插件(Video DownloadHelper)

    Video DownloadHelper是一款以最简单的方式下载网页视频的chrome插件 基本上火狐浏览器能够加载出视频流并正常播放的视频该插件都能够抓取 可以说该插件对于网页视频下载还是十分快捷并且使用场景广泛的 本地安装后的版本 Vi
  • 字节数组的妙用

    在计算机高级语言中 字节属于最小单位 例如在Java中 int占用4个字节 long占用8个字节等 基本上所有基本类型 包括String 都可以转换成字节 那么这到底有何作用 本篇博客主要是记录了我使用字节数组的经验 希望可以给大家提供一些
  • volatile详解(任何人都能懂的那种)

    volatile 看了好多篇博客终于明白这个关键字到底是干嘛的 让我综合所有的博客写一篇大家都能理解它的博客 要点赞呦 volatile是一个类型修饰符 作用是作为指令关键字 一般都是和const对应 确保本条指令不会被编译器的优化而忽略
  • 怎么将将 Python 的安装目录添加到了系统的环境变量路径中

    要在 Windows 系统中将 Python 的安装目录添加到环境变量路径中 请按照以下步骤操作 1 打开 控制面板 2 选择 系统和安全 3 选择 系统 4 在 系统属性 中 选择 高级系统设置 5 在 高级 选项卡中 选择 环境变量 6
  • 【微信小程序】微信小程序支付功能实现

    1 前言 微信小程序支付 开启新一代便捷支付新时代 随着互联网技术的不断发展 微信小程序支付已经成为了人们日常生活中不可或缺的一部分 微信小程序是一种无需下载安装即可使用的应用 用户可以通过微信扫描或搜索关键词来打开并使用 而微信小程序支付
  • javascript面试题--持续更新

    前端HTML篇 前端CSS篇 Vue篇 TypeScript篇 React篇 微信小程序篇 前端面试题汇总大全 含答案超详细 HTML JS CSS汇总篇 持续更新 前端面试题汇总大全二 含答案超详细 Vue TypeScript Reac
  • pytorch量化库使用(2)

    FX Graph Mode量化模式 训练后量化有多种量化类型 仅权重 动态和静态 配置通过qconfig mapping prepare fx函数的参数 完成 FXPTQ API 示例 import torch from torch ao