Meta 内部都在用的 FX 工具大起底:利用 Graph Transformation 优化 PyTorch 模型

2023-11-14

PyTorch 中的 graph mode 在性能方面表示更为出色,本文介绍 Torch.FX 这个强大工具,可以捕捉和优化 PyTorch 程序 graph。

一、简介

PyTorch 支持两种执行模式:eager mode 和 graph mode。

eager mode 中,模型中的运算符在读取时会立即执行,它易于使用,对机器学习从业者更友好,因此被设置为默认的执行模式。

graph mode 中,运算符先被合成一个 graph,然后作为一个整体进行编译和执行,它的性能更高,因此在实际生产中大量使用。

具体来说,graph mode 支持算子融合,两个算子通过合并,可以降低或本地化内存读取以及内核启动总开销。

融合可以是横向 (horizontal) 的:采取应用于多个  operand 的单一操作(如 BatchNorm),并将这些  operand 合并到一个数组中。

融合也可以是纵向 (vertical) 的:将一个内核与另一个内核合并,后者需要使用第一个内核的输出(如 ReLU 后接卷积)。

Torch.FX(缩写为 FX)是一个公开可用的工具包,作为 PyTorch 软件包的一部分,支持 graph mode 的执行。它可以:

1. 从 PyTorch 程序中获取 graph

2. 允许开发者在获取的 graph 上编写 transformation

Meta 内部先前已经在用 FX 来优化生产模型 (production model) 的训练吞吐量 (training throughput)。本文将通过介绍 Meta 开发的基于 FX 的优化,来展示利用图结构转换 (graph transformation) 优化 PyTorch 部署模型性能的方法。

二、背景

embedding table 广泛存在于推荐系统中,本节将介绍 FX 和 embedding table 的背景知识。

2.1. FX 

图 1 是一个简单示例,演示了如何用 FX 转换  PyTorch 程序,它包含三个步骤:

  •  从程序中获取 graph

  • 修改 graph(在本例中,我们用 GELU 代替 RELU)

  • 从修改后的 graph 中生成一个新程序

图1:在 PyTorch 模块中用 GELU 取代 RELU 的 FX

FX API 为检查和转换 PyTorch 程序 graph 还提供了许多其他功能。

2.2. embedding table 

图2:批尺寸=1 的稀疏特征 embedding table 示意图

在推荐系统中,稀疏特征(例如,User ID,Story ID)由 embedding table 表示。

embedding table E 是一个 HxD 矩阵,其中 H 是哈希大小,D 是嵌入向量维度。E 的每一行都是一个浮点数向量。

feature hashing 的作用是将一个稀疏特征映射到 E的索引列表中,例如 [S1,S2,...,Sk],其中 0≤Si<H。它的输出值计算为 f(E[S1], E[S2],...,E[Sk]),其中 E[Si] 是 Si 行的向量,f  是池化函数,通常是 sum,average,max 三个函数之一。

为了充分利用 GPU,稀疏特征通常为批处理。批处理中的每个实体都有自己的索引列表。如果一个批次有 B 个实体,可以简单理解为一个表征有 B 个索引列表。

更为严谨的表示方法是将 B 个索引列表合并成一个索引列表,并添加一个索引长度的列表(该批中的每个实体都有一个长度 length)。

例如,如果一批包含 3 个实体,其索引列表如下:

  •  Entity 1: indices = [10, 20]

  •  Entity 2: indices = [5, 9, 77, 81]

  •  Entity 3: indices = [15, 20, 45]

则完整批尺寸的 indice 和 length 将是:

  •  Indices = [10, 20, 5, 9, 77, 81, 15, 20, 45]

  •  Lengths = [2, 4, 3]

而整个 batch 的 embedding table 查询,输出为是一个 BxD 矩阵。

三、3 种 FX Transformation 

PyTorch 更新了 3 个 FX transformation,以加速对 embedding table 的访问,本节将逐一介绍。

下文 3.1 关于将多个小输入张量结合成一个大张量的转换;3.2 关于将多个并行计算链融合成一个计算链的转换;3.3 关于将通信与计算重叠的转换。

 3.1 结合输入稀疏特征 

batch 中的每个输入稀疏特征,都可以表示为两个列表:一个索引列表和一个 B length 列表,其中 B 表示批尺寸。

在 PyTorch 中,这两个列表都可以以张量的形式存在。当 PyTorch 模型在 GPU 上运行时,embedding table 通常存储在 GPU 内存中(它更接近 GPU,读写带宽比 CPU 内存更高)。

需要使用输入稀疏特征时,两个张量都要先从 CPU 复制到 GPU。然而每个主机到设备的内存复制都需要启动内核,这对于实际的数据传输来说,会更加耗费时间。

如果一个模型使用了多个输入稀疏特征,这种复制可能成为性能瓶颈(例如,1000 个输入稀疏特征将需要从主机到设备复制 2000 个张量)。

一个减少主机到设备 memcpy 数量的优化方法,就是在多个输入稀疏特征发送到设备之前,先将其进行组合。

例如,给定以下三个输入特征:

  •  Feature_A: indices = [106, 211, 7], lengths = [2, 1]

  •  Feature_B: indices = [52, 498, 616, 870, 1013], lengths = [3, 2]

  •  Feature_C: indices = [2011, 19, 351, 790], lengths = [1, 3]

组合后的形式为:

Features_A_B_C: indices = [106, 211, 7, 52, 498, 616, 870, 1013, 2011, 19, 351, 790], lengths = [2, 1, 3, 2, 1, 3]

所以不需要从主机到设备复制 3x2=6 个张量,只需要复制 2 个张量。

图 3(b) 描述了这种优化的实现,它包含两个组件:

  •  CPU 端:输入 pipeline 被修改为将所有稀疏特征的 indices 组合成一个张量,所有 length 组合成另一个张量。然后将这两个张量复制到 GPU 上。

  •  GPU 端:使用 FX,在模型 graph 中插入一个Permute_and_Split 算子,从合并的张量中恢复单个特征 indices 和 length 张量,并将其发送至下游的相应节点。

 

优化前:两个张量都要从 CPU 复制到 GPU

优化后:将输入稀疏特征进行组合

3.2 从访问 embedding table 开始的计算链横向融合 

在一个生产模型中,每个 GPU 上有 10 个 embedding table 很常见。出于性能方面的考虑,对这些 table 的查询被分到一组,这样它们的输出就被串联在一个大张量中(见图 4(a)中的红色部分)。

为了对单个特征输出进行计算,使用 Split 算子将大张量分成 N 个小张量(其中 N 为特征的数量),然后将所需的计算应用于每个张量。

如图 4(a) 所示,应用于每个特征输出 O 的计算是Tanh(LayerNorm(O))。所有的计算结果都被串联成一个大的张量,然后传递给下游的算子(图 4(a) 中的 Op1)。

这里主要的 runtime cost 是 GPU 内核启动的开销。例如,图 4(a) 中的 GPU 内核的启动次数为 2*N+3(图中的每个椭圆都表示一个 GPU 内核)。这会影响性能,因为 LayerNorm 和 Tanh 在 GPU 上的执行时间,与它们的内核启动时间相比很短。

此外,Split 算子可能会创建一个额外的嵌入向量输出张量的副本,消耗额外的 GPU 内存。

用 FX 来实现一种叫做横向融合 (horizontal fusion) 的优化,可以大大减少 GPU 内核的启动次数(在这个例子中,优化后的 GPU 内核启动次数为 5,见图 4(b))。

使用 Add_middle_dim 算子代替显式 Split,将 shape 为 (B, NxD) 的 2D 嵌入张量重塑为 shape 为 (B, N, D) 的 3D 张量。接下来将一个单一的 LayerNorm 应用到它的最后一维。对 LayerNorm 的结果应用一个 Tanh。最后,用 Remove_middle_dim 算子将 Tanh 的结果恢复成 2D 张量。

由于 Add_middle_dim 和 Remove_middle_dim 只是重塑张量,没有创建额外的副本,所以也可以减少 GPU 内存的消耗。

优化前:所有输出被串联到一个大张量中

进行横向融合优化后

3.3 计算与通信间的重叠 (overlap) 

面向投产的推荐模型的训练,通常是在分布式 GPU 系统上完成的。由于每个 GPU 的设备内存容量不足以容纳模型中的所有 embedding table,因此需要将其分布在多个 GPU 上。

在训练步骤中,GPU 需要从其他 GPU 上的 embedding table 中读取/写入特征值。这被称为 all-to-all 通信,可能是影响性能的重要原因。

通过 FX 实现一个 transformation,可以将计算与 all-to-all 通信重叠。图 5(a) 显示了一个具备嵌入向量 table 访问 (EmbeddingAllToAll) 及其他算子的模型 graph 实例。如图 5(b) 所示,在没有任何优化的情况下,它们会在一个 GPU 流上顺序执行。

使用FX将 EmbeddingAllToAll 分成  EmbeddingAllToAll_Request和EmbeddingAllToAll_Wait,并在它们之间安排独立的算子。

图5:计算与通信的重叠

 3.4 总结 

表1:本节讨论的优化及解决的相应性能瓶颈

为了发现哪些模型会从这些 transformation 中受益,开发人员对 MAIProf 收集的运行在 Meta 数据中心的模型的性能数据进行分析。得出与 eager mode 相比,这些 transformation 在一组生产模型上实现了 2-3 倍的速度提升。

四、结语

从性能角度考量,PyTorch 中的 graph mode 比生产环境中使用的 eager mode 更受欢迎。FX 是一个强大的工具,可以捕捉和优化 PyTorch 程序 graph。本文展示了三种 FX transformation,用于优化 Meta 内部的生产推荐模型。

最后希望更多 PyTorch 开发者可以使用 graph transformation 来提升模型的性能。

—— 完 ——

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Meta 内部都在用的 FX 工具大起底:利用 Graph Transformation 优化 PyTorch 模型 的相关文章

随机推荐

  • python的四种输出形式_pandas实现导出数据的四种方式

    本文主要介绍了pandas导出数据到文件的四种方式 分享给大家 主要也是给自己留个笔记 具体如下 import pandas as pd import pymysql df pd DataFrame A 3 4 8 9 B 1 2 2 4
  • torch 指定显卡

    1 代码中指定 import os os environ CUDA VISIBLE DEVICES gpu ids 2 shell中指定 export CUDA VISIBLE DEVICES gpu ids python3 train p
  • matlab二元函数驻点,求matlab高手帮忙写个关于对二元函数积分的程序 - 程序语言 - 小木虫 - 学术 科研 互动社区...

    clear clc build matrice m1 0 8 0 1i 0 5i 0 3i 0 3 0 4i m2 0 1 0 05i 0 2i 0 0 4 0 9i S 1 0 6 0 4i differentiate symbolic
  • mybatis insert方法返回数据是什么

    正在学习mybatis 在salsession调用对应sql语句之后 会有一个返回值 看到有人说返回值是1为true 0为false 还有说是返回对应表的主键id的 其实都不对 后来 查阅资料发现这个值默认是 受影响的行数 而且过程中还遇到
  • 通过ThreadLocal和HandlerInterceptor实现java后台业务埋点日志功能

    目前公司的方案是用mdc来实现一个请求的业务数据埋点记录 但是mdc是map方式 需要手动设置key 而且每次都要手动clear 一是不方便管理 再者如果忘记clear会造成业务埋点数据混乱 所以有了想要把埋点数据字段统一封装的想法 这样方
  • 【DETR】DETR训练VOC数据集/自己的数据集

    训练DETR 一 数据准备 二 配置DETR 三 绘图 四 推理 五 一些小bug 1 取整问题 2 num class的设置问题 References 一 数据准备 DETR用的是COCO格式的数据集 如果要用DETR训练自己的数据集 直
  • 【Oracle】获取最近工作日及前N个工作日

    需求 日历表 TCALENDAR DATES 工作日flag 1 非工作日 0 取任一查询日期最近工作日 及最近工作日前n个工作日 日历表TCALENDAR DATES样式 SELECT T BASE DATE T CAL DAY C LA
  • 基于WR703N路由器的WIFI机器人

    可以说 wifi机器人是一个比较成熟作品了 特别是使用wr703制作wifi机器人的有很多例子 因为1 其体积小 2 实时获取视频相比STM32容易 STM32F1系列性能不够 使用OV系列的摄像头较为吃力 3 可以使用路由器连接外网 使用
  • Linux系统:常用服务端口

    目录 一 理论 1 端口分类 2 传输协议 3 常用端口 一 理论 1 端口分类 一个计算机最多有65535个端口 端口不能重复 Linux 只有 root 用户可以使用 1024 以下的端口 表1 端口分类 端口 范围 说明 公认端口 W
  • 12306验证码的一些思考

    12306的验证码长这个样子 让选择图片 看起来非常完美的图片验证码 比那些简单又没有实用的字母数字验证码组合强太多了 那些字母数字组合直接获取图片光学识别然后填进表单就可以攻破 我也想实现这样的 怎么去实现呢 设计一个简单点的吧 1 我先
  • JS(解构) 之数组和对象中提取数据总结

    解构含义 解构功能含义 从复杂数据类型中 数组或对象 中提取数据的过程 JS 解构 之数组 从数组中提取首个元素 方式一 基于数组下标提取元素 const names zzg zcx zcy const it names 0 console
  • CAPL编程实现诊断刷写,车联网FOTA流程自动化测试(代码篇)

    原创内容 转载请注明出处 接上篇 本文主要讲CAPL编程详细实现 软件环境CANoe 11 0 一 Simulation Setup 1 建模之前 首先创建一个 DBC文件 如果不会 可以用一个已有的DBC文件修改 新建待仿真的空节点 如下
  • Linux系统的启动流程

    一 开机启动流程图 第一步 开机自检就是开始工作之前先对自己的工具进行检查是否正常 BIOS就是主板上的一个自检程序 开机先对主板上自带的和外界的一些开机必备的设备进行检测 比如CPU 显卡 内存 硬盘等设备的自检过程就是自检 第二步 MB
  • 【斯坦福CS224W笔记之二】传统图机器学习的特征工程 — 节点

    Traditional Methods for ML on Graphs 是根据同济子豪兄学长的中文讲解做的笔记哦 感兴趣的话可以直接去b站观看详细视频 传送带 https github com TommyZihao zihao cours
  • Flask 框架

    目录 Flask介绍和安装 请求与响应 请求 响应 登录案例 配置文件写法 路由系统 路由写法 转换器 CBV session的使用和原理 flask session的使用 闪现flash 请求扩展 g对象 蓝图 小型蓝图 大型蓝图 数据库
  • 搭建AI智能语音外呼系统 智能语音外呼机器人

    随着人工智能技术的发展 近半年来涌现了大量基于人工智能的呼叫中心业务服务商和集成商 仅电销机器人这一个方向就至少有近百家公司正在推广运营 包括百度 讯飞 智齿 硅基 百应 箭鱼 容联等 商务上的需求非常强烈 整个市场都飞快地热闹起来 一套可
  • 小细节{变量名-枚举}

    一 类的变量名第一个字母一定要小写 eventType event type eventId 13 userId 45 openingFlag true Data TableName user activity AllArgsConstru
  • 基于matlab的车牌识别

    20221126 新增 首先说一下这个工程的思路 很多朋友妄想直接拿着工程用 那是不可能的 自己学去叭 我是先将车牌号预处理之后 整个图片干净一点之后 进行每个字符的切割 但是是很投机取巧的方法 是先切好第一个字符 再根据切割坐标 切割下一
  • 堆排序算法的具体分析和实现

    定义 堆就是完全二叉树的数据结构 堆排序是利用二叉树的孩子与双亲节点的比较来实现的排序方法 大顶堆 每个节点的值都大于或者等于它的左右子节点的值 小顶堆 每个节点的值都小于或者等于它的左右子节点的值 这里使用的是大顶堆 基本思想 堆排序的基
  • Meta 内部都在用的 FX 工具大起底:利用 Graph Transformation 优化 PyTorch 模型

    PyTorch 中的 graph mode 在性能方面表示更为出色 本文介绍 Torch FX 这个强大工具 可以捕捉和优化 PyTorch 程序 graph 一 简介 PyTorch 支持两种执行模式 eager mode 和 graph