多gpu训练梯度如何计算,求和是否要求平均

2023-11-06

作者:智星云服务
链接:https://www.zhihu.com/question/271226455/answer/1521784627
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
 

作者:itsAndy
链接:https://www.zhihu.com/question/271226455/answer/1411456425
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

还是推荐使用智星云吧,有多卡的服务器配置,环境都是配置好的,用起来很方便顺手。
关于多gpu训练,tf并没有给太多的学习资料,比较官方的只有
:tensorflow-models/tutorials/image/cifar10/cifar10_multi_gpu_train.py
但代码比较简单,只是针对cifar做了数据并行的多gpu训练,利用到的layer、activation类型不多,针对更复杂网络的情况,并没有给出指导。
一、思路
单GPU时,思路很简单,前向、后向都在一个GPU上进行,模型参数更新时只涉及一个GPU。多GPU时,有模型并行和数据并行两种情况。模型并行指模型的不同部分在不同GPU上运行。数据并行指不同GPU上训练数据不同,但模型是同一个(相当于是同一个模型的副本)。在此只考虑数据并行,这个在tf的实现思路如下:
模型参数保存在一个指定gpu/cpu上,模型参数的副本在不同gpu上,每次训练,提供batch_size*gpu_num数据,并等量拆分成多个batch,分别送入不同GPU。前向在不同gpu上进行,模型参数更新时,将多个GPU后向计算得到的梯度数据进行平均,并在指定GPU/CPU上利用梯度数据更新模型参数。
假设有两个GPU(gpu0,gpu1),模型参数实际存放在cpu0上,实际一次训练过程如下图所示:
 


二、tf代码实现
大部分需要修改的部分集中在构建计算图上,假设在构建计算图时,数据部分基于tensorflow1.4版本的dataset类,那么代码要按照如下方式编写: next_img, next_label = iterator.get_next()
image_splits = tf.split(next_img, num_gpus) label_splits = tf.split(next_label, num_gpus) tower_grads = [] tower_loss = [] counter = 0 for d in self.gpu_id: with tf.device('/gpu:%s' % d): with tf.name_scope('%s_%s' % ('tower', d)): cross_entropy = build_train_model(image_splits[counter], label_splits[counter], for_training=True) counter += 1 with tf.variable_scope("loss"): grads = opt.compute_gradients(cross_entropy) tower_grads.append(grads) tower_loss.append(cross_entropy) tf.get_variable_scope().reuse_variables() mean_loss = tf.stack(axis=0, values=tower_loss) mean_loss = tf.reduce_mean(mean_loss, 0) mean_grads = util.average_gradients(tower_grads) update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = opt.apply_gradients(mean_grads, global_step=global_step)

第1行得到image和对应label
第2-3行对image和label根据使用的gpu数量做平均拆分(默认两个gpu运算能力相同,如果gpu运算能力不同,可以自己设定拆分策略)
第 4-5行,保存来自不同GPU计算出的梯度、loss列表
第7-16行,开始在每个GPU上创建计算图,最重要的是14-16三行,14,15把当前GPU计算出的梯度、loss值append到列表后,以便后续计算平均值。16行表示同名变量将会复用,这个是什么意思呢?假设现在gpu0上创建了两个变量var0,var1,那么在gpu1上创建计算图的时候,如果还有var0和var1,则默认复用之前gpu0上的创建的那两个值。
第18-20行计算不同GPU获取的grad、loss的平均值,其中第20行使用了cifar10_multi_gpu_train.py中的函数。
第23行利用梯度平均值更新参数。
注意:上述代码中,所有变量(vars)都放在了第一个GPU上,运行时会发现第一个GPU占用的显存比其他GPU多一些。如果想把变量放在CPU上,则需要在创建计算图时,针对每层使用到的变量进行设备指定,很麻烦,所以建议把变量放在GPU上。
单机多GPU训练
先简单介绍下单机的多GPU训练,然后再介绍分布式的多机多GPU训练。 单机的多GPU训练, tensorflow的官方已经给了一个cifar的例子,已经有比较详细的代码和文档介绍, 这里大致说下多GPU的过程,以便方便引入到多机多GPU的介绍。 单机多GPU的训练过程:

  1. 假设你的机器上有3个GPU;
  2. 在单机单GPU的训练中,数据是一个batch一个batch的训练。 在单机多GPU中,数据一次处理3个batch(假设是3个GPU训练), 每个GPU处理一个batch的数据计算。
  3. 变量,或者说参数,保存在CPU上
  4. 刚开始的时候数据由CPU分发给3个GPU, 在GPU上完成了计算,得到每个batch要更新的梯度。
  5. 然后在CPU上收集完了3个GPU上的要更新的梯度, 计算一下平均梯度,然后更新参数。
  6. 然后继续循环这个过程。

通过这个过程,处理的速度取决于最慢的那个GPU的速度。如果3个GPU的处理速度差不多的话, 处理速度就相当于单机单GPU的速度的3倍减去数据在CPU和GPU之间传输的开销,实际的效率提升看CPU和GPU之间数据的速度和处理数据的大小。
通俗解释
写到这里觉得自己写的还是不同通俗易懂, 下面就打一个更加通俗的比方来解释一下:
老师给小明和小华布置了10000张纸的乘法题并且把所有的乘法的结果加起来, 每张纸上有128道乘法题。 这里一张纸就是一个batch, batch_size就是128. 小明算加法比较快, 小华算乘法比较快,于是小华就负责计算乘法, 小明负责把小华的乘法结果加起来 。 这样小明就是CPU,小华就是GPU.
这样计算的话, 预计小明和小华两个人得要花费一个星期的时间才能完成老师布置的题目。 于是小明就招来2个算乘法也很快的小红和小亮。 于是每次小明就给小华,小红,小亮各分发一张纸,让他们算乘法, 他们三个人算完了之后, 把结果告诉小明, 小明把他们的结果加起来,然后再给他们没人分发一张算乘法的纸,依次循环,知道所有的算完。
这里小明采用的是同步模式,就是每次要等他们三个都算完了之后, 再统一算加法,算完了加法之后,再给他们三个分发纸张。这样速度就取决于他们三个中算乘法算的最慢的那个人, 和分发纸张的速度。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

多gpu训练梯度如何计算,求和是否要求平均 的相关文章

  • 小程序中关于红包雨的实现

    一 原型依据 在我这个项目中小程序端所需要实现的只有红包雨的下落动画和通屏背景图的兼容 关于红包点击金额的计算是由后端实现的 首先来看下需要实现的效果图 二 实现代码 首先是第一次进入的页面 在这个页面的时候会进行静默登录 静默登录成功的话
  • webpack : 无法加载文件 C:\Program Files\nodejs\webpack.ps1

    webpack 无法加载文件 C Program Files nodejs webpack ps1 1 问题 2 解决办法 1 问题 使用webpack打包是报错如下 webpack 无法加载文件 C Program Files nodej
  • LINUX系统下:Cuda+Cudnn+Tensorflow-GPU环境配置学习总结

    1 cuda cudnn安装 1 1下载cuda 1 1 1查看系统支持的cuda版本 可以安装低于该版本的 不能超过该版本 nvidia smi 1 1 2下载cuda cuda历史版下载 1 2 3安装 1 找到下载的cuda文件所在的
  • CUnit用法总结

    简介 CUnit是一个用C语言写的单元测试库 它是一个简单的测试框架 提供了丰富的断言语句来测试常用的数据类型 此外 对于跑测试用例和反馈测试结果 CUnit都有一些不同的接口 它可以编译成动态库或者静态库 基本框架 CUnit是一个可以跨
  • sqlhelper集成dynamic多数据源的分页问题(非教学向)

    一 问题描述 最近接手 顶锅 了公司的框架维护工作 第一项任务就是集成dynamic多数据源框架 dynamic官方使用文档 本文不是教学 有兴趣的小伙伴可以自己查阅文档 集成dynamic之后 一切都很顺利 但是测试到SQLHelper框
  • SQLITE学习之SQLITE基础知识(一)

    1 SQLITE常见命令 sqlite常用命令被称为 SQLite 的点命令 这些命令的不同之处在于它们不以分号 结束 我们只需在ubuntu终端界面上的命令提示符 下键入一个简单的 sqlite3 命令 在 SQLite 命令提示符 gt
  • python解数独

    在学典型优化问题模型与算法的时候发现 暑假的解数独的部分 可以设计三个模型 比如唯一余数 基础摒除法等 让他们循环运行 同时设计一个步数 多次循环来找到步数最少的解题路径 当然还会遇到这三个模型解决不了的问题 这时候就需要增加模型了 sud
  • docker save和docker export区别

    两者区别 docker save用于导出镜像到文件 包含镜像元数据和历史信息 docker export用于将当前容器状态导出至文件 类似快照 所以不包含元数据及历史信息 体积更小 此外从容器快照导入时也可以重新指定标签和元数据信息 一 导
  • LINUX 系统编程之文件IO

    文件IO 属于系统IO 是系统内核向用户空间提供的接口 直接调用内核提供的系统调用函数 头文件是unistd h 1 open char s flag mode 在fcntl h头文件种声明 函数的作用 创建或打开某个文件 最多包含三个参数
  • java bean对象属性复制,将一个对象的属性值赋值给另一个对象,对象之间的复制方法

    注意依赖 springframework下的复制顺序为 目标对象 新对象 import org springframework beans BeanUtils public static void main String args Inte
  • java 获取 sessionid_通过sessionid获取session方法

    使用HttpSessionListener来监听session的创建和销毁 首先创建HttpSessionListener的实现类 SessionListeners java packagecom test importjava util
  • 【详细学习Docker部署搭建高可用的MySQL集群环境】

    一 MySQL高可用集群搭建 MySQL集群搭建在实际项目中是非常必须的 接下来我们来学习通过PXC Percona XtraDB Cluster 来实现强一致性数据库集群搭建 1 1 MySQL集群搭建 1 1 1 中央仓库查找相关镜像
  • 三年级计算机考试题目及答案,三年级信息技术试题及答案.doc

    三年级信息技术试题及答案 三年级信息技术期末试题 学校 班级 姓名 分数 一 单项选择题 共10题 每小题4 共计40分 1 计算机的心脏是 显示屏 鼠标 2 输入汉字时我们需要选择输入法 是我们使用的输入法之一 它的名字叫 五笔输入法 智
  • 【刷题笔记7】LeetCode 54. 螺旋矩阵(数组模拟)

    用分享的方式成长 用有趣的眼光看世界 欢迎来到12 26 25的博客 热爱编码 算法 知识总结 不定期更新有趣 有料 有营养内容 让我们共同学习 共同进步 系列索引 刷题笔记0 系列目录索引 持续更新 推荐收藏 本题题目 LeetCode
  • NVIDIA FasterTransformer

    NVIDIA FasterTransformer NVIDIA GPU计算专家团队针对transformer推理提出了性能优化方案 FasterTransformer 截止到2022年7月 这套方案支持的模型涵盖了BERT GPT Long
  • Mybatis整合Spring -- typeAliasesPackage

    Mybatis 整合 Spring integration MapperScannerConfigurer Mybatis整合Spring 根据官方的说法 在ibatis3 也就是Mybatis3问世之前 Spring3的开发工作就已经完成
  • python时间计算 周开始第一天和结束天 通过年周计算

    python def year mon for check year week 通过年周获取当前月 按每周最后一天的月份比对 最后一天为周日 end year week str year str week 0 end week result

随机推荐

  • xss入门闯关详解6-10关

    继续进行6 10关 第6关 简单的尝试之后发现闭合掉了 尝试空格或者大小写 tab绕过 大小写成功绕过 Onclick alert 1 第七关 老样子 value gt click alert 1 gt value gt lt gt ale
  • 取消idm下载器和google浏览器的关联(让谷歌浏览器禁止使用idm插件)

    https jingyan baidu com article 597035529ae46b8fc107405d html IDM下载安装成功之后 会自动默认关联你电脑上的所有浏览器 在使用浏览器下载的时候自动会变成IDM下载 如果不想让I
  • 2018-2019-2 网络对抗技术 20165335 Exp2 后门原理与实践

    一 基础问题回答 1 例举你能想到的一个后门进入到你系统中的可能方式 钓鱼网站 搞一个假网站 假淘宝 盗版电影 文库下载文档什么的 下载东西的时候把带隐藏的后门程序附带下载进去 自启动 反弹连接 搞一个小网站 用iframe标签跳转到危险网
  • 自动化测试工具Parasoft c++ test v2021.1全新发布,简化嵌入式测试

    随着Parasoft C C test 2021 1的发布 嵌入式测试和开发团队获得了现代高度自动化CI CD管道的速度和效率 最新版本为团队提供了完全集成的静态和单元测试 以实现持续合规性和质量的交付 新版本继续全面支持最新的合规标准 包
  • 几种查找的时间复杂度

    1 顺序查找 1 最好情况 要查找的第一个就是 时间复杂度为 O 1 2 最坏情况 最后一个是要查找的元素 时间复杂度未 O n 3 平均情况下就是 n 1 2 所以总的来说时间复杂度为 O n 2 二分查找 O log2n gt log以
  • 在Ubuntu上编译安装LLVM

    章节索引 Motivation 环境 Git 下载LLVM源码 CMake 编译 安装 文件链接 验证 后记 Motivation 两周前实验室要求我配置一个叫Speedy js的编译器 配置这个编译器需要先配置好LLVM 根据这个编译器作
  • Unity 流程控制

    异步函数 Invoke 被调用的方法不是立刻执行 而是过一段时间后才执行 注 Invoke是不能接受有参数的方法的 Invoke是受Time timeScale的影响 所以当Time timeScale 0 的时候 Invoke是无效的 因
  • python测试url是否可访问,网站是否连通的方法

    目录 前言 1 requests库 1 1 传参 1 2 响应内容 2 python web 前言 一般这种方法用在校验 比如 前端界面传回后端的url 如果返回值不是200 不保存其值 调用的接口不通 直接返回非200 爬虫网站 验证ur
  • C# Dictionary用法总结

    1 常规用法 增加键值对之前需要判断是否存在该键 如果已经存在该键而且不判断 将抛出异常 所以这样每次都要进行判断 很麻烦 在备注里使用了一个扩展方法 public static void DicSample1 Dictionary
  • 【今日CV 计算机视觉论文速览 第111期】Fri, 3 May 2019

    今日CS CV 计算机视觉论文速览 Fri 3 May 2019 Totally 29 papers 上期速览 更多精彩请移步主页 Interesting Single Image Portrait Relighting单图肖像光照重建 本
  • 【Java基础】使用HashMap和For循环查找数据并且统计耗费时长

    准备一个ArrayList其中存放3000000 三百万个 Hero对象 其名称是随机的 格式是hero 4位随机数 hero 3229 hero 6232 hero 9365 因为总数很大 所以几乎每种都有重复 把名字叫做 hero 55
  • ICS大作业--hello程序人生

    摘 要 Hello是每个程序员第一个接触的程序 在本文中利用计算机系统所学的知识 基于Linux平台 通过gcc objdump gdb edb等工具 从c源程序的预处理开始 跟踪分析程序编译 汇编 链接 进程管理 存储管理 I O管理整个
  • 基于Matlab的激光雷达与单目摄像头联合外参标定

    1 背景介绍 目前团队正在与某主机厂合作开发L4级自动驾驶出租车项目 所用的传感器包括80线激光雷达和单目摄像头 为了充分利用Lidar和Cam各自的优势 设计了一种融合算法 需要获得Lidar2Camera的联合外参 前期使用Autowa
  • Java数组&二维数组

    Java数组 二维数组 1 一维数组 1 1 数组介绍 数组就是存储数据长度固定的容器 存储多个数据的数据类型要一致 1 2 数组的定义格式 1 2 1 第一种格式 数据类型 数组名 示例 int arr double arr char a
  • python调用多级目录中的文件_python复制多层目录下的文件至其他盘符对应的目录中...

    tmp c cmd js d TZT2 0 js d TZT js d TZT 346 226 207 346 241 243 350 257 264 346 230 216 json d c modules config js d css
  • Go_包、工程管理

    包 包其实就是文件夹 把文件分类放到不同的包利于管理 作用 如果把所有的代码都放在一个文件中 后续的可维护性 阅读性都比较差 所以可以使用包的来区分不同的模块 功能分别放在不同包中 然后其它的文件使用到功能就调用就可以了 在同一个包下是函数
  • android studio在XML预览中出现Rendering Problems问题

    http blog csdn net u014365838 article details 52078501 如图所示 出现问题的原因是style不和规范 虽然没有大问题 编译也可以通过 不过总有些碍眼 解决方法 1 在styles xml
  • apache-ant build.xml 实例

    文章目录 version apache ant 1 9 15 build xml
  • php发送邮件样式_php简单实现发送带附件的邮件

    这篇文章主要介绍了php简单实现发送带附件的邮件 涉及附件上传及邮件发送的相关技巧 需要的朋友可以参考下 本文实例讲述了php简单实现发送带附件的邮件 分享给大家供大家参考 具体如下 下面是静态html代码 带附件的邮件发送 发送人 收件人
  • 多gpu训练梯度如何计算,求和是否要求平均

    作者 智星云服务 链接 https www zhihu com question 271226455 answer 1521784627 来源 知乎 著作权归作者所有 商业转载请联系作者获得授权 非商业转载请注明出处 作者 itsAndy