pytorch每日一学24(torch.quantize_per_tensor()、torch.quantize_per_channel())使用映射过程将tensor进行量化

2023-11-09

第24个方法

torch.quantize_per_tensor(input, scale, zero_point, dtype) → Tensor


torch.quantize_per_channel(input, scales, zero_points, axis, dtype) → Tensor
  • 以上两个方法是将将浮点张量转换为具有给定比例和零点的量化张量。

Quantization(量化)介绍

  量化是指用于执行计算并以低于浮点精度的位宽存储张量的技术。量化模型对张量使用整数而不是浮点值执行部分或全部运算。这使得可以采用更紧凑的模型表示,并可以在许多硬件平台上使用高性能矢量化操作。与典型的FP32型号相比,PyTorch支持INT8量化,从而可将模型大小减少4倍,并将内存带宽要求减少4倍。与FP32计算相比,对INT8计算的硬件支持通常快2到4倍。

  • 量化主要是一种加速推理的技术,并且量化算子仅支持前向操作,只有在进行forward的时候才可以采用量化操作进行加速,在进行backward时不可以进行使用量化操作来进行加速。

  PyTorch支持多种量化深度学习模型的方法。在大多数情况下,该模型在FP32中训练,然后将模型转换为INT8。此外,PyTorch还支持量化意识训练,该训练使用伪量化模块对前向和后向传递中的量化误差进行建模。注意,整个计算是在浮点数中进行的。在量化意识训练结束时,PyTorch提供转换功能,将训练后的模型转换为较低的精度。


  在较低级别,PyTorch提供了一种表示量化张量并对其执行操作的方法。它们可用于直接构建以较低的精度执行全部或部分计算的模型。提供了更高级别的API,这些API包含了将FP32模型转换为较低精度并减少精度损失的典型工作流程。
  • 总的来说,量化就是将机器中使用的浮点tensor转化为整数tensor使得计算变得更加的快,也可以减少资源的消耗,而且这种方法不会使准确度下降很多。如下表:
    在这里插入图片描述
    在这里插入图片描述

进行量化的方法有以下三个大类(存在于torch.quantization名称空间中),今天讲第一个大类:在这里插入图片描述


开始介绍方法

  • torch.quantize_per_tensor()是按照tensor来进行转化的,每个tensor中所有数据进行一样的操作

  • torch.quantize_per_channel()是对每个channel都有一组对应的缩放和零点,对每个channel进行不同的变化。

  • torch.quantize_per_tensor()官方例子:

>>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8)
tensor([-1.,  0.,  1.,  2.], size=(4,), dtype=torch.quint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10)
>>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8).int_repr()
tensor([ 0, 10, 20, 30], dtype=torch.uint8)

参数介绍:

  • input:要量化的tensor。
  • scale:应用在量化公式的缩放大小。
  • zero_point:以整数值表示的偏移量,该值映射为浮点数零。
  • dtype:返回张量的所需数据类型。必须是量化的dtypes之一:torch.quint8,torch.qint8,torch.qint32

对张量进行量化的时候使用的如下公式(Q为量化后的张量, x为输入):
在这里插入图片描述

函数映射讲解

  • 这个地方很多文章的讲的不是很细致,所以我刚开始只看他们的文章也没有怎么看懂,然后自己动手写了一下又想了一下,才搞懂了是怎么回事,果然还是得动手写啊。
    观看如下例子:
    在这里插入图片描述
  • 会发现根本没有什么变化,那么到底进行了这个转化有什么意义?其实是有意义的,我们再看下面的例子:
    在这里插入图片描述
  • 会发现100变成了24.5,而其他的并没有改变。其实我们这里的a中的元素,就拿1来说吧,它缩放过后,也就是 1 0.1 + 10 = 20 \frac{1}{0.1}+10=20 0.11+10=20,但是这里并没有这个结果,我们使用如下方法:
    在这里插入图片描述
  • 发现上面的数经过我们的公式进行缩放以后,确实得到下面这几个数,其实tensor.int_repr()将给定量化的Tensor(注意只能是量化的tensor),此方法返回以uint8_t作为数据类型的CPU Tensor,该数据类型存储给定Tensor的基础uint8_t值。
  • 所以其实我们上面使用torch.quantize_per_tensor()已经进行了操作,它在内部存储中已经是我们想要的形式了,只不过是显示的时候还是按照原来的样子显示而已。

那么到底表示什么意思呢?

  • 上面的b中的元素(不是a)表示了离中心点(例子中为10)的距离,注意此距离还要进行缩放,例如b中第二个元素为1,它表示了距离中心点10为10,因为它还要进行除以0.1的缩放操作,这样就能理解为什么存放的还是1。
  • 但是a中的100为什么变成了b中的24.5?因为8位无符号数范围为 0 ∼ 255 0\sim255 0255,最大为255(即b.int_repr()中最大的值为255),距离中心点(10)最大的距离只能是245,然后还要乘以0.1,所以就是24.5。同理,8位无符号数最小的数是0,所以距离10最远是-10即(b.int_repr()中最小的值是-10),再乘以缩放的话,b中的最小值是-1,如下所示:
    在这里插入图片描述
  • 所以此方法的作用就是存放了表示离中心点的距离(截取最大最小,中间的保留),注意还要进行缩放。

理解了上面那个,那其实torch.quantize_per_channel就很好理解了,对每个channel指定缩放和中心,并且分别进行这样的处理与变化。如下所示:

>>> x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
>>> torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8)
tensor([[-1.,  0.],
        [ 1.,  2.]], size=(2, 2), dtype=torch.quint8,
       quantization_scheme=torch.per_channel_affine,
       scale=tensor([0.1000, 0.0100], dtype=torch.float64),
       zero_point=tensor([10,  0]), axis=0)
>>> torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8).int_repr()
tensor([[  0,  10],
        [100, 200]], dtype=torch.uint8)

参数介绍:

  • input:要量化的tensor。
  • scales:浮点数要使用的一维张量,大小应与input.size(axis)相匹配。
  • zero_points:要使用的偏移量的整数1D张量,大小应与input.size(axis)相匹配。
  • axis:指定在哪个维度使用per-channel量化。
  • dtype:返回张量的所需数据类型。必须是量化的dtypes之一:torch.quint8,torch.qint8,torch.qint32
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch每日一学24(torch.quantize_per_tensor()、torch.quantize_per_channel())使用映射过程将tensor进行量化 的相关文章

  • 是否有解决方法可以通过 CoinGecko API 安全检查?

    我在工作中运行我的代码 一切都很顺利 但在不同的网络 家庭 WiFi 上 我不断收到403访问时出错CoinGecko V3 API https www coingecko com api documentations v3 可以观察到 在
  • 为什么从 Pandas 1.0 中删除了日期时间?

    我在 pandas 中处理大量数据分析并每天使用 pandas datetime 最近我收到警告 FutureWarning pandas datetime 类已弃用 并将在未来版本中从 pandas 中删除 改为从 datetime 模块
  • Django 的内联管理:一个“预填充”字段

    我正在开发我的第一个 Django 项目 我希望用户能够在管理中创建自定义表单 并向其中添加字段当他或她需要它们时 为此 我在我的项目中添加了一个可重用的应用程序 可在 github 上找到 https github com stephen
  • Pandas/Google BigQuery:架构不匹配导致上传失败

    我的谷歌表中的架构如下所示 price datetime DATETIME symbol STRING bid open FLOAT bid high FLOAT bid low FLOAT bid close FLOAT ask open
  • 删除flask中的一对一关系

    我目前正在使用 Flask 开发一个应用程序 并且在删除一对一关系中的项目时遇到了一个大问题 我的模型中有以下结构 class User db Model tablename user user id db Column db String
  • 使用 kivy textinput 的 'input_type' 属性的问题

    您好 我在使用 kivy 的文本输入小部件的 input type 属性时遇到问题 问题是我制作了两个自定义文本输入 其中一个称为 StrText 其中设置了 input type text 然后是第二个文本输入 名为 NumText 其
  • 独立滚动矩阵的行

    我有一个矩阵 准确地说 是 2d numpy ndarray A np array 4 0 0 1 2 3 0 0 5 我想滚动每一行A根据另一个数组中的滚动值独立地 r np array 2 0 1 也就是说 我想这样做 print np
  • 使用字典映射数据帧索引

    为什么不df index map dict 工作就像df column name map dict 这是尝试使用index map的一个小例子 import pandas as pd df pd DataFrame one A 10 B 2
  • Pandas Merge (pd.merge) 如何设置索引和连接

    我有两个 pandas 数据框 dfLeft 和 dfRight 以日期作为索引 dfLeft cusip factorL date 2012 01 03 XXXX 4 5 2012 01 03 YYYY 6 2 2012 01 04 XX
  • 如何在 Python 中解析和比较 ISO 8601 持续时间? [关闭]

    Closed 这个问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 我正在寻找一个 Python v2 库 它允许我解析和比较 ISO 8601 持续时间may处于不同单
  • 如何使用 pybrain 黑盒优化训练神经网络来处理监督数据集?

    我玩了一下 pybrain 了解如何生成具有自定义架构的神经网络 并使用反向传播算法将它们训练为监督数据集 然而 我对优化算法以及任务 学习代理和环境的概念感到困惑 例如 我将如何实现一个神经网络 例如 1 以使用 pybrain 遗传算法
  • 不同编程语言中的浮点数学

    我知道浮点数学充其量可能是丑陋的 但我想知道是否有人可以解释以下怪癖 在大多数编程语言中 我测试了 0 4 到 0 2 的加法会产生轻微的错误 而 0 4 0 1 0 1 则不会产生错误 两者计算不平等的原因是什么 在各自的编程语言中可以采
  • 仅第一个加载的 Django 站点有效

    我最近向 stackoverflow 提交了一个问题 标题为使用mod wsgi在apache上多次请求后Django无限加载 https stackoverflow com questions 71705909 django infini
  • Pandas 将多行列数据帧转换为单行多列数据帧

    我的数据框如下 code df Car measurements Before After amb temp 30 268212 26 627491 engine temp 41 812730 39 254255 engine eff 15
  • 如何在 pygtk 中创建新信号

    我创建了一个 python 对象 但我想在它上面发送信号 我让它继承自 gobject GObject 但似乎没有任何方法可以在我的对象上创建新信号 您还可以在类定义中定义信号 class MyGObjectClass gobject GO
  • 将 Python 中的日期与日期时间进行比较

    所以我有一个日期列表 datetime date 2013 7 9 datetime date 2013 7 12 datetime date 2013 7 15 datetime date 2013 7 18 datetime date
  • Scipy Sparse:SciPy/NumPy 更新后出现奇异矩阵警告

    我的问题是由大型电阻器系统的节点分析产生的 我基本上是在设置一个大的稀疏矩阵A 我的解向量b 我正在尝试求解线性方程A x b 为了做到这一点 我正在使用scipy sparse linalg spsolve method 直到最近 一切都
  • 在 JavaScript 函数的 Django 模板中转义字符串参数

    我有一个 JavaScript 函数 它返回一组对象 return Func id name 例如 我在传递包含引号的字符串时遇到问题 Dr Seuss ABC BOOk 是无效语法 I tried name safe 但无济于事 有什么解
  • 更改 Tk 标签小部件中单个单词的颜色

    我想更改 Tkinter 标签小部件中单个单词的字体颜色 我知道可以使用文本小部件来实现与我想要完成的类似的事情 例如使单词 YELLOW 显示为黄色 self text tag config tag yel fg clr yellow s
  • Kivy - 单击按钮时编辑标签

    我希望 Button1 在单击时编辑标签 etykietka 但我不知道如何操作 你有什么想法吗 class Zastepstwa App def build self lista WebOps getList layout BoxLayo

随机推荐

  • 2020.11.1

    1 登录业务的完善 1 1后端控制页面跳转 if rs next System out println 登陆成功 request setAttribute name username request getRequestDispatcher
  • JAVA经典兔子问题

    有一对兔子 从出生第三个月起每个月都生一对兔子 小兔子长到第三个月后 每个月又生一对兔子 假如兔子都不死 问M个月时兔子的数量 很经典的斐波那契数列问题 记得第一次看到这道题是在一次比赛中 当时并不知道斐波那契数列 但是列出来几个月的兔子数
  • SQL 数据库中如何自动生成订单号

    有一张表TAB1 字段num num里有很多数字 我想从1开始 查到里面缺少的最小的一个数字 例如 4 5 6 8 9 11 12 13 这样的话我想要的结果是1 1 2 3 4 5这样的话 我想要的结果是6 其实利用正常排序的找第一个不正
  • open3d操作.ply文件(点云)

    读取 ply文件 import open3d as o3d pcd o3d io read point cloud ply path format ply ppoints np asarray pcd points pcolors np a
  • Ubuntu20.04编译安装openpose使用pythonAPI

    目录 项目地址 环境 准备 开始编译 项目地址 https github com CMU Perceptual Computing Lab openpose 环境 系统 ubuntu20 04 cuda 11 2 GPU 3090 2 Dr
  • 第一款中国人自主研发的普及型计算机高级编程语言

    最近有了比较大的技术突破 可以实现快速的开发环境了 我之前的计划一直是解析脚本来实现迈欧网的开发环境 有了这个技术 虽然是高级语言 但是却不会丧失性能 达到C 等语言的速度 甚至更快 希望朋友们支持我 你们的支持是我不间断开发此产品的动力
  • muduo启程

    muduo启程 muduo 是一个基于 Reactor 模式的现代 C 网络库 它采用非阻塞 IO 模型 基于事件驱动和回调 原生支持多核多线程 适合编写 Linux 服务端多线程网络应用程序
  • 使用广度优先搜索查找图中路径(java)

    package breadthfirstpaths import edu princeton cs algs4 Graph import edu princeton cs algs4 Queue import edu princeton c
  • Android四大组件之service(二)

    在 Android四大组件之service 一 文中我们讲到了 service 的 基本概念 和 startService 启动方式 stopService 不过这种方式是有个缺点 我们无法调用 FirstService 类里面的方法 这个
  • webStrom智能提示忽略首字母大小写问题

    Settings gt Editor gt Ceneral gt Code Completion gt Case sensitive completion 设置为None
  • vuex中的mutations的两种调用方法

    直接通过 store commit调用
  • Ubuntu14.04 安装ffmpeg

    一 xvid x264 ffmpeg源码下载 链接 https pan baidu com s 13phSFrLqkGrKDGF3 a2cSA 提取码 ls2s 二 安装 1 xvid tar zxvf xvidcore 1 3 3 tar
  • 一文带你看懂Spring事务!

    点击上方 方志朋 选择 设为星标 做积极的人 而不是积极废人 前言 Spring事务管理我相信大家都用得很多 但可能仅仅局限于一个 Transactional注解或者在XML中配置事务相关的东西 不管怎么说 日常可能足够我们去用了 但作为程
  • 677. 键值映射

    实现一个 MapSum 类 支持两个方法 insert 和 sum MapSum 初始化 MapSum 对象 void insert String key int val 插入 key val 键值对 字符串表示键 key 整数表示值 va
  • 面试之计算机网络

    计算机网络 1 路由选择协议 常见的路由选择协议有 RIP协议 OSPF协议 RIP协议 底层是贝尔曼福特算法 它选择路由的度量标准 metric 是跳数 最大跳数是15跳 如果大于15跳 它就会丢弃数据包 OSPF协议 底层是迪杰斯特拉算
  • IDEA 设置默认Maven的路径

    文件 新项目设置 构建工具 Maven 修改主路径
  • linux调整queue_depth,linux – 无法编辑/ sys / block / sdX / device / queue_depth文件

    我正在尝试使用以下命令增加SSD的队列深度值 echo 64 gt sys block sda device queue depth 但是我收到以下错误 bash echo write error Invalid argument 我尝试使
  • STM32CubeIDE HAL库操作IIC (一)配置篇

    目录 一 MX配置 使能中断 可选 DMA设置 可选 二 生成的代码 三 IIC通信的三种方式 Polling IT DMA 代码源自官方例程 1 Polling 常用 2 IT 开启中断 接收到数据时会调用回调函数 3 DMA模式 回调函
  • Qt 如何使用正则表达式 正则表达式 密码 email

    Qt 正则表达式 regular expression 详细用法查看此博客 https blog csdn net dongdong csdn article details 78574168 QRegExp regExpPsw 正则表达式
  • pytorch每日一学24(torch.quantize_per_tensor()、torch.quantize_per_channel())使用映射过程将tensor进行量化

    第24个方法 torch quantize per tensor input scale zero point dtype Tensor torch quantize per channel input scales zero points