nvidia训练深度学习模型利器apex使用解读

2023-11-17

本文参考:

英伟达(NVIDIA)训练深度学习模型神器APEX使用指南_咆哮的阿杰的博客-CSDN博客_apex 英伟达

Pytorch混合精度(FP16&FP32)(AMP)/半精度 训练(二) —— 代码示例 apex pytorch_hxxjxw的博客-CSDN博客_apex pytorch

目录

 一、背景

二、apex介绍

三、apex配置

四、代码实现

1、三行代码示例

2、opt_level参数设置

3、测试代码(不带amp)

4、测试代码(带amp)

5、swin-transformer算法添加amp实战 

五、溢出问题


 一、背景

gpu显存不大,很多模型没法跑,不能用很大的batch size等导致loss没法降低。使用apex工具可以从中解脱出来。

二、apex介绍

apex是nvidia开源的,完美支持pytorch框架,用于改变数据格式来减小模型显存占用的工具

其中最有价值的是amp(Automatic Mixed Precision),将模型的大部分操作都用float16数据类型替代,一些特别操作仍然使用float32.

并且用户仅仅通过三行代码即可完美将自己的训练代码迁移到该模型。

实验证明,使用float16作为大部分操作的数据类型,并没有降低参数,在一些实验中,反而由于可以增大batch size,带来精度上的提升,以及训练速度上的提升。

它号称能够在不降低性能的情况下,将模型训练的速度提升2~4倍,训练显存消耗减少为之前的一半。

三、apex配置

见:windows11安装apex工具_benben044的博客-CSDN博客

四、代码实现

1、三行代码示例

from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level="O1") 
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

2、opt_level参数设置

只有一个opt_level需要用户自行配置

  • O0:纯FP32训练,可以作为accuracy的baseline
  • O1:混合精度训练(推荐使用),根据黑白名单自动决定使用FP16(GEMM,卷积)还是FP32(Softmax)进行计算
  • O2:"几乎FP16"混合精度训练,不存在黑白名单,除了Batch Norm,几乎都是用FP16计算
  • O3:纯FP16训练,很不稳定,但是可以作为speed的baseline。

3、测试代码(不带amp)

import torch

N, D_in, D_out = 64, 1024, 512
x = torch.randn(N, D_in, device='cuda')
y = torch.randn(N, D_out, device='cuda')
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for _ in range(1000):
    y_pred = model(x)
    loss = torch.nn.functional.mse_loss(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

4、测试代码(带amp)

import torch
from apex import amp

N, D_in, D_out = 64, 1024, 512
x = torch.randn(N, D_in, device='cuda')
y = torch.randn(N, D_out, device='cuda')
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
for _ in range(1000):
    y_pred = model(x)
    loss = torch.nn.functional.mse_loss(y_pred, y)
    optimizer.zero_grad()
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()

5、swin-transformer算法添加amp实战 

以上总共3处修改点。

测试结果:2g显存的batch_size从4增加到了6,1.5倍,似乎没有想象中那么多。

五、溢出问题

因为float16保存数据位数少了,能保存数据的上限和下限的绝对值也小了。

如果我们在处理分割类问题,需要用到一些涉及到求和的操作,如sigmoid,softmax,这些操作都涉及到求和。

分割问题特征图很大,求个sigmoid可能会导致数据溢出,得到错误的结果。

所以针对这些操作,仍然使用float32作为数据格式。

修改方式:仅需在模型定义中,在构造函数__init__中的某一个位置,加上下面这段:

from apex import amp
class xxxNet(Module):
	def __init__(using_map=False)
		...
		...
		if using_amp:
		     amp.register_float_function(torch, 'sigmoid')
		     amp.register_float_function(torch, 'softmax')

用register_float_function指明后面的函数需要使用float类型,注意第二实参是string类型。

和register_float_function相似的注册函数还有:

  • amp.register_half_function(module, function_name)
  • amp.register_float_function(module, function_name)
  • amp.register_promote_function(module, function_name)

需要在使用amp.initialize之前使用注册函数,所以最号的位置就放在模型的构造函数中。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

nvidia训练深度学习模型利器apex使用解读 的相关文章

  • 达梦数据库使用安装用户打开图形化工具显示无权限

    在x86虚拟机下 使用达梦数据库安装用户安装数据库后 经常需要使用安装用户打开诸如manager console等图形化管理工具 这时候经常遇到安装用户没有权限执行图形化界面的打开脚本 如下图 dmdba为安装数据库的用户 这实际上是dmd
  • 在C++泛型编程中如何只特化类的某个成员函数

    我们知道在C 模板编程中如果我们特化或是偏特化某个模板类 我们需要重写整个模板类中的所有函数 但是这些代码通常是非常相似的 甚至在某些情况下可能只有一两个函数会不一样 其他函数都是一样的 在这种情况下 同时存在多份相同的代码 对我们维护这些
  • 两个无序单链合并成一个有序单链表

    解题思路 两个无序链表先转换成两个有序单链表 两个有序单链表合并成一个有序单链表 代码 import java util 链表 class Node int val Node next public Node int val this va
  • iframe height 100% 问题

    iframe height 100 问题 最近 利用 MapGuide 技术开发一个 WebGIS 的应用程序 其中用到了 标签 可是当我调试运行的时候 其 width 100 生效了 但 height 100 就无效 无论用 JavaSc

随机推荐

  • 人工智能发展月报(2022年11月)

    本期导读 11月 人工智能业界热度较上月降温 共计发生576篇新闻 188个事件 热度总体趋势如下图所示 本月事件较多与世界互联网大会相关 期间多个会议活动及成果发布受到业内持续高度关注 此外 工信部等五部门发布的 虚拟现实与行业应用融合发
  • Liunx下pip3换源(最详细)

    在使用python时我们经常会安装各种包 我们一般安装的方式都是pip3 install xx模块 但是pip3默认源https pypi org 安装的过程非常慢 可能都是几k几k的 有时安装这安装着 直接error了 还有一种情况是直接
  • 微信公众号订阅消息

    1 官网介绍 功能介绍 微信开放文档 订阅通知是一个用户主动订阅 服务号按需下发的通知能力 使用过程请遵守 微信公众平台服务协议 微信公众平台运营规范 如有疑问 可在微信开放社区反馈 设置订阅功能 服务号可以在图文消息 网页等场景设置订阅功
  • iOS基本内存管理:autorelease和autoreleasepool

    在内存管理的Objective C代码里 一个Cocoa对象存在于一个生命周期 有明确的阶段 它被创建 初始化 并使用 也就是 其它对象发送消息给它 它还可能会被保留 拷贝 或压缩 并最终被释放和销毁 AD 1 autorelease 基本
  • ChatGPT 和爬虫有什么区别?

    ChatGPT是一种基于人工智能的对话模型 它通过训练大量的文本数据来生成自然语言回复 它可以用于实现智能对话系统 能够理解用户的输入并生成相应的回复 ChatGPT的目标是模拟人类对话 使得对话更加流畅和自然 而爬虫是一种用于自动化地从互
  • 算法 - 基数排序(Radix Sort)

    基数排序非常适合用于整数排序 尤其是非负整数 因此只演示对非负整数进行基数排序 执行流程 一次对个位数 十位数 百位数 千位数 万位数 进行排序 从低位到高位 个位数 十位数 百位数的取值范围都是固定的0 9 可以使用计数排序对它们进行排序
  • QT_6(信号连接信号、Lambda表达式)

    信号连接信号 运行代码 修改mywidget cpp文件如下 这是窗口界面 include mywidget h include
  • 关于建筑物半自动化提取方法的总结

    基于边界 基于边界的交互式提取方法要求用户指定目标边界的少量关键点或大概位置 然后基于目标边界强度和连续性等特征 对目标的边界进行准确跟踪 常见的基于边界方法是Snake算法和智能剪刀 Intelligent Scissors 基于边界方法
  • [附源码]计算机毕业设计社区生活废品回收APPSpringboot程序

    项目运行 环境配置 Jdk1 8 Tomcat7 0 Mysql HBuilderX Webstorm也行 Eclispe IntelliJ IDEA Eclispe MyEclispe Sts都支持 项目技术 SSM mybatis Ma
  • 微信支付 api v3 支付通知 异步 验签失败 PHP

    微信支付v3 异步验签失败 此处我们接收参数 报文主体 一般是通过框架 自带的request接收 例如TP6 this gt request gt param 这里如果使用此接收方式在进行json转换验签会失败 我们需要用原生的接收方式 f
  • python之pandas简单介绍及使用(一)

    一 Pandas简介 1 Python Data Analysis Library 或 pandas 是基于NumPy 的一种工具 该工具是为了解决数据分析任务而创建的 Pandas 纳入了大量库和一些标准的数据模型 提供了高效地操作大型数
  • JVM系列(三) JVM垃圾判断及强引用关系

    1 判断垃圾对象 如何判断该对象是垃圾 或者该对象要被回收 引用计数法 在对象中添加一个引用计数器 每当有一个地方引用它时 计数器值就 1 当引用失效时 计数器值就 1 任何时刻计数器为 0 的对象就是没人用的 那么就要被回收 优点是原理简
  • Game101课程笔记_lecture10_几何

    Game101课程笔记 lecture10 几何 1 纹理应用 1 环境光 环境贴图 2 凹凸贴图 法线贴图 1 Bump Mapping 3 位移贴图 4 三维纹理 5 环境光遮蔽 6 体渲染3D Texture and Volume R
  • Bookface(中位数,保序回归)

    include
  • int 和 Integer 作为接收参数类型,参数长度不能大于10?

    今天的博客主题 Java开发路上的小坑坑 int 和 Integer 作为接收参数类型 参数长度不能大于10 int 和 Integer 作为接收参数类型 参数长度不能大于10 What 就问问小菜鸟惊讶不惊讶 大佬略过 public st
  • Shell中的$0、$1、$2的含义

    0 就是编写的shell脚本本身的名字 1 是在运行shell脚本传的第一个参数 2 是在运行shell脚本传的第二个参数 如 新建了一个shell脚本 test sh bin sh echo shell脚本名称 0 echo 传到shel
  • three.js引用FontLoader()报错Unexpected token < in JSON at position 1

    使用three js的FontLoader 时 总是报错 文件也是正常的引入的json文件 但是还是报错 后来各自百度发现是文件路径的问题 报错时我使用的是相对路径 const loader new THREE FontLoader con
  • background 背景属性详解

    background 背景属性 我们知道元素有前景色color 与之对应的还有背景色 通过background我们可以为元素添加实色 background color 和任意多个背景图片 background image css 背景常见属
  • 计算机网络(二):TCP篇

    文章目录 1 TCP头部包含哪些内容 2 为什么需要 TCP 协议 TCP 工作在哪一层 3 什么是 TCP 4 什么是 TCP 连接 5 如何唯一确定一个 TCP 连接呢 6 UDP头部大小是多少 包含哪些内容 7 TCP与UDP的区别
  • nvidia训练深度学习模型利器apex使用解读

    本文参考 英伟达 NVIDIA 训练深度学习模型神器APEX使用指南 咆哮的阿杰的博客 CSDN博客 apex 英伟达 Pytorch混合精度 FP16 FP32 AMP 半精度 训练 二 代码示例 apex pytorch hxxjxw的