【Keras】TensorFlow分布式训练

2023-11-13

当我们拥有大量计算资源时,通过使用合适的分布式策略,我们可以充分利用这些计算资源,从而大幅压缩模型训练的时间。针对不同的使用场景,TensorFlow 在 tf.distribute.Strategy 中为我们提供了若干种分布式策略,使得我们能够更高效地训练模型。

一、单机多卡训练: MirroredStrategy

tf.distribute.MirroredStrategy 是一种简单且高性能的,数据并行的同步式分布式策略,主要支持多个 GPU 在同一台主机上训练。使用这种策略时,我们只需实例化一个 MirroredStrategy 策略:

strategy = tf.distribute.MirroredStrategy()

并将模型构建的代码放入 strategy.scope() 的上下文环境中:

with strategy.scope():
    # 模型构建代码

可以在参数中指定设备,如:

strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"])

即指定只使用第 0、1 号 GPU 参与分布式策略。

以下代码展示了使用 MirroredStrategy 策略,在 TensorFlow Datasets 中的部分图像数据集上使用 Keras 训练 MobileNetV2 的过程:

import tensorflow as tf
import tensorflow_datasets as tfds

num_epochs = 5
batch_size_per_replica = 64
learning_rate = 0.001

strategy = tf.distribute.MirroredStrategy()
print('Number of devices: %d' % strategy.num_replicas_in_sync)  # 输出设备数量
batch_size = batch_size_per_replica * strategy.num_replicas_in_sync

# 载入数据集并预处理
def resize(image, label):
    image = tf.image.resize(image, [224, 224]) / 255.0
    return image, label

# 使用 TensorFlow Datasets 载入猫狗分类数据集,详见“TensorFlow Datasets数据集载入”一章
dataset = tfds.load("cats_vs_dogs", split=tfds.Split.TRAIN, as_supervised=True)
dataset = dataset.map(resize).shuffle(1024).batch(batch_size)

with strategy.scope():
    model = tf.keras.applications.MobileNetV2(weights=None, classes=2)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        metrics=[tf.keras.metrics.sparse_categorical_accuracy]
    )

model.fit(dataset, epochs=num_epochs)

在以下的测试中,我们使用同一台主机上的 4 块 NVIDIA GeForce GTX 1080 Ti 显卡进行单机多卡的模型训练。所有测试的 epoch 数均为 5。使用单机无分布式配置时,虽然机器依然具有 4 块显卡,但程序不使用分布式的设置,直接进行训练,Batch Size 设置为 64。使用单机四卡时,测试总 Batch Size 为 64(分发到单台机器的 Batch Size 为 16)和总 Batch Size 为 256(分发到单台机器的 Batch Size 为 64)两种情况。

数据集 单机无分布式(Batch Size 为 64) 单机四卡(总 Batch Size 为 64) 单机四卡(总 Batch Size 为 256)
cats_vs_dogs 146s/epoch 39s/epoch 29s/epoch
tf_flowers 22s/epoch 7s/epoch 5s/epoch

可见,使用 MirroredStrategy 后,模型训练的速度有了大幅度的提高。在所有显卡性能接近的情况下,训练时长与显卡的数目接近于反比关系。

MirroredStrategy 的步骤如下:

  1. 训练开始前,该策略在所有 N 个计算设备上均各复制一份完整的模型;
  2. 每次训练传入一个批次的数据时,将数据分成 N 份,分别传入 N 个计算设备(即数据并行);
  3. N 个计算设备使用本地变量(镜像变量)分别计算自己所获得的部分数据的梯度;
  4. 使用分布式计算的 All-reduce 操作,在计算设备间高效交换梯度数据并进行求和,使得最终每个设备都有了所有设备的梯度之和;
  5. 使用梯度求和的结果更新本地变量(镜像变量);
  6. 当所有设备均更新本地变量后,进行下一轮训练(即该并行策略是同步的)。

默认情况下,TensorFlow 中的 MirroredStrategy 策略使用 NVIDIA NCCL 进行 All-reduce 操作。

二、多机训练: MultiWorkerMirroredStrategy

多机训练的方法和单机多卡类似,将 MirroredStrategy 更换为适合多机训练的 MultiWorkerMirroredStrategy 即可。不过,由于涉及到多台计算机之间的通讯,还需要进行一些额外的设置。具体而言,需要设置环境变量 TF_CONFIG ,示例如下:

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:20000", "localhost:20001"]
    },
    'task': {'type': 'worker', 'index': 0}
})

TF_CONFIG 由 cluster 和 task 两部分组成:

  • cluster 说明了整个多机集群的结构和每台机器的网络地址(IP + 端口号)。对于每一台机器,cluster 的值都是相同的;
  • task 说明了当前机器的角色。例如, {‘type’: ‘worker’, ‘index’: 0} 说明当前机器是 cluster 中的第 0 个 worker(即 localhost:20000 )。每一台机器的 task 值都需要针对当前主机进行分别的设置。

以上内容设置完成后,在所有的机器上逐个运行训练代码即可。先运行的代码在尚未与其他主机连接时会进入监听状态,待整个集群的连接建立完毕后,所有的机器即会同时开始训练。

请在各台机器上均注意防火墙的设置,尤其是需要开放与其他主机通信的端口。如上例的 0 号 worker 需要开放 20000 端口,1 号 worker 需要开放 20001 端口。

以下示例的训练任务与前节相同,只不过迁移到了多机训练环境。假设我们有两台机器,即首先在两台机器上均部署下面的程序,唯一的区别是 task 部分,第一台机器设置为 {‘type’: ‘worker’, ‘index’: 0} ,第二台机器设置为 {‘type’: ‘worker’, ‘index’: 1} 。接下来,在两台机器上依次运行程序,待通讯成功后,即会自动开始训练流程。

import tensorflow as tf
import tensorflow_datasets as tfds
import os
import json

num_epochs = 5
batch_size_per_replica = 64
learning_rate = 0.001

num_workers = 2
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:20000", "localhost:20001"]
    },
    'task': {'type': 'worker', 'index': 0}
})
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
batch_size = batch_size_per_replica * num_workers

def resize(image, label):
    image = tf.image.resize(image, [224, 224]) / 255.0
    return image, label

dataset = tfds.load("cats_vs_dogs", split=tfds.Split.TRAIN, as_supervised=True)
dataset = dataset.map(resize).shuffle(1024).batch(batch_size)

with strategy.scope():
    model = tf.keras.applications.MobileNetV2(weights=None, classes=2)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        metrics=[tf.keras.metrics.sparse_categorical_accuracy]
    )

model.fit(dataset, epochs=num_epochs)

在以下测试中,我们在 Google Cloud Platform 分别建立两台具有单张 NVIDIA Tesla K80 的虚拟机实例,并分别测试在使用一个 GPU 时的训练时长和使用两台虚拟机实例进行分布式训练的训练时长。所有测试的 epoch 数均为 5。使用单机单卡时,Batch Size 设置为 64。使用双机单卡时,测试总 Batch Size 为 64(分发到单台机器的 Batch Size 为 32)和总 Batch Size 为 128(分发到单台机器的 Batch Size 为 64)两种情况。

数据集 单机单卡(Batch Size 为 64) 双机单卡(总 Batch Size 为 64) 双机单卡(总 Batch Size 为 128)
cats_vs_dogs 1622s 858s 755s
tf_flowers 301s 152s 144s

可见模型训练的速度同样有大幅度的提高。在所有机器性能接近的情况下,训练时长与机器的数目接近于反比关系。

参考资料

TensorFlow 分布式训练

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【Keras】TensorFlow分布式训练 的相关文章

随机推荐

  • Vue自定义全局指令

    当我们使用 Vue 构建应用时 经常需要在模板中添加一些自定义的指令 来实现期望的功能 这些指令可以全局定义 也可以定义在组件内 全局指令是指在应用的 main js 文件或其他入口文件中注册的指令 可以在应用的任何组件中使用 定义全局指令
  • Qt 主窗口点击按钮 弹出另一个自定义窗口

    为将要进行的工作做准备 代码实现功能 单击某个按钮后 弹出对话框 对话框中的内容可自行设计 1 建立一个主界面 主界面中有一个pushbotton按键 2 右键项目 gt 添加新文件 gt Qt设计师界面类 Part II 对话框的模态和非
  • Kaldi数据下载很慢

    运行kaldi 的run sh文件时 数据集在openslr上 数据集比较大 例如aishell 总共15G 国内网络情况下下载比较慢 1 修改为国内地址 例如 aishell 默认的run sh里写的是www openslr org re
  • html5 颜色随机变化,每次在HTML5 Canvas的.fillStyle中使用时,将画布图案随机化为不同的颜色(randomizing a canvas pattern to be a diff...

    每次在HTML5 Canvas的 fillStyle中使用时 将画布图案随机化为不同的颜色 randomizing a canvas pattern to be a different color every time it is used
  • 文本数据导入HBASE库找不到类com/google/common/collect/Multimap

    文本数据导入HBASE库找不到类com google common collect Multimap 打算将文本文件导入HBASE库 在运行命令的时候找不到类 com google common collect Multima hadoop
  • PyTorch实战使用Resnet迁移学习

    PyTorch实战使用Resnet迁移学习 项目结构 项目任务 项目代码 网络模型测试 项目结构 数据集存放在flower data文件夹 cat to name json是makejson文件运行生成的 TorchVision文件主要存放
  • springboot 整合 mongodb

    前言 前面通过 5 节的内容 我们学习了 mongodb 的使用 这节我们通过学过的知识运用 springboot 整合 mongodb 搭建一个小项目 项目搭建 springboot 基于 maven 项目搭建的具体过程这里就不再赘述了
  • style标签上的scoped属性

    vue中 在vue文件中的style标签上有一个特殊的属性 scoped 布尔值 作用 该属性的作用是将当前标签下的样式私有化 仅对当前组件起作用 只管当前组件和子组件的最外层 不控制自组件 原理 若是给style标签添加了scoped属性
  • 【python与数据分析】Python与数据分析概述

    目录 一 认识数据分析 1 数据分析的方法论与数据分析方法 一 七何分析法 建立框架 二 演绎树分析法 问题分层 三 PEST分析法 设计环境 四 金字塔原理分析法 建立逻辑 五 4P营销理论分析法 业务指导 六 SWOT分析法 战略竞争
  • 计算机专业大二了啥都没学怎么办

    如果您是计算机专业的大二学生 但是还没有学到很多内容 那么您可以考虑以下几点 加强自学 首先 您可以自学一些基本的计算机知识 如操作系统 算法 数据结构等 寻找资源 您可以寻找一些在线的学习资源 如课程 书籍 视频等 加深自己的知识储备 向
  • C语言函数递归例题讲解(超详细~)

    文章目录 递归题型讲解 例题1 例题2 例题3 递归题型讲解 例题1 根据下面递归函数 调用函数Fun 2 返回值是多少 int Fun int n if n 5 return 2 else return 2 Fun n 1 A 2 B 4
  • fit_transform含义

    fit transform X train 找出X train的均值和 标准差 并应用在X train上 对于X test 直接使用transform方法 此时StandardScaler已经保存了X train的均值和标准差
  • telnet 查看端口是否可访问

    1 首先为什么要写这篇文章 说到为什么还得从DNS服务器说起 我在我的电脑上安装了DNS服务器 但是用网络去访问还怎么都访问都不上去 于是我就打开dos窗口 用ping命令查看是否可以ping 如 ping 125 34 49 211 通
  • FANUC机器人程序实例

    FANUC机器人程序实例 PS 1 2 2 3 7 8 8 9 9 10 10 7为圆弧运动 6 1 3 4 4 5 5 6 6 7 7 6 为直线运动 先画图1 循环3次 等待3秒 再画图2 轨迹如上图所示 10个位置在同一平面 程序 位
  • linux配置SSH

    目录 一 ssh简介 二 ssh配置文件 三 ssh远程登录 四 ssh远程登录原理 4 1 对称加密 4 3非对称加密 一 ssh简介 SSH为建立在应用层上的安全协议 SSH是目前非常可靠 专门为远程登录会话和其它网了服务提供安全性的协
  • 程序员微信名昵称_微信营销手段之昵称命名

    这段时间 我玩微信玩的不少 但是主要还是把精力放在我的QQ空间了 但是我不管在微信上还是扣扣空间我都发现了一个怪现象 就是一直有人用广告名称做昵称 啥是广告名字呢 比如她是卖化妆品的 然后她就把微信昵称改成了某某化妆品销售 代购什么的 这种
  • 史上最全Unity3D游戏开发教程,从入门到精通(含学习路线图)

    Unity现在已经用的很广泛啦 可是却一直没有什么美术向的教程 程序员方面的内容在各个论坛都有讨论 但是美术似乎很弱势啊 明明美术也很需要掌握引擎方面的内容嘛 山谷里的野百合还有春天呢 我们美术也要出教程 很多同学想学习unity却不知道怎
  • react中背景图片和图片引入的方法

    有三种引入背景图片的方法 1 div 2 先import引入图片路径 再用es6语法中的 引用 import bgImage from assets images bgImage webp div 3 用require进行路径引用requi
  • 【图像增强】Debiased Subjective Assessment of Real-World Image Enhancement

    最近学习了CVPR2021的一篇文章 真实世界图像增强的去偏主观质量评价 Debiased Subjective Assessment of Real World Image Enhancement 一 前言 图像质量评价 Image Qu
  • 【Keras】TensorFlow分布式训练

    当我们拥有大量计算资源时 通过使用合适的分布式策略 我们可以充分利用这些计算资源 从而大幅压缩模型训练的时间 针对不同的使用场景 TensorFlow 在 tf distribute Strategy 中为我们提供了若干种分布式策略 使得我