PyTorch分布式训练

2023-11-01

PyTorch切分模型和数据两种方法:

  1. DataParallel是单进程多线程的,只用于单机情况;

  1. DistributedDataParallel支持模型并行,同时适用于单机和多机情况。多进程,每个进程都有独立的优化器,执行自己的更新过程,梯度通过通信传递到每个进程(GPU之间只传递梯度),所有执行的内容是相同的。

DistributedDataParallel内部机制

https://pytorch.org/docs/stable/notes/ddp.html

大致流程

Figure 1. 多机多卡并行训练流程

  1. 初始化进程组。

  1. 创建分布式并行模型,每个进程都会有相同的模型和参数。

  1. 创建数据分发Sampler,使每个进程加载一个Batch中不同部分的数据。

  1. 每个进程前向传播并各自计算梯度。

  1. 模型某一层的参数得到梯度后会马上进行通讯并进行梯度平均

  1. 各GPU更新模型参数。

初始化
# 初始化分布式环境
dist.init_process_group(                                   
        backend='nccl',                                         
        init_method='env://',                                   
        world_size=args.world_size,                              
        rank=rank                                               
    )

其中backend参数指定通信后端,包括mpi, gloo, nccl。nccl是Nvidia提供的官方多卡通信框架,相对比较高效;mpi也是高性能计算常用的通信协议,不过需要自己安装MPI实现框架,比如OpenMPI;gloo是内置通信后端,但是不够高效。init_method指的是如何初始化,以完成刚开始的进程同步;这里设置的env://指的是环境变量初始化方式,需要在环境变量中配置4个参数:MASTER_PORT,MASTER_ADDR,WORLD_SIZE,RANK。

初始化即建立一个默认的分布式进程组 (distributed process group),这个group同时会初始化Pytorch的torch.distributed包,后续可以直接用torch.distributed的API进行分布式的基本操作。

import torch.distributed as dist
模型侧
torch.cuda.set_device(opt.local_rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[opt.local_rank])
rank = dist.get_rank()
device = torch.device(f'cuda:{opt.local_rank}')

dist.barrier()  # synchronizes all processes

rank是全局的进程序号,要根据这个参数来设置每个进程所使用的device设备。

local_rank是指的训练进程在当前节点的序号,即所采用的GPU编号。对于local_rank的获取有两种方式,方式一是在训练脚本添加一个命令行参数,程序启动时会对其自动赋值;方式二是采用torch.distributed.launch启动时加上--use_env=True,该情况下会设置LOCAL_RANK这个环境变量。

# 方式一
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int)
args = parser.parse_args()
local_rank = args.local_rank

# 方式二
local_rank=int(os.environ["LOCAL_RANK"])
数据侧
data_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)
dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batchSize,
        shuffle=False,
        num_workers=int(opt.nThreads),
        drop_last=True,
        sampler=data_sampler
        )

注意训练循环过程的每个epoch开始时调用data_sampler.set_epoch(epoch),主要是为了保证每个epoch的划分是不同的,其它的训练代码都保持不变。

有效batch_size其实是batch_size_per_gpu * world_size (world_size为节点总数nnodes乘以每个节点的GPU数nproc_per_node)。

启动方式
python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE_PER_NODE
           --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" --master_port=1234
           YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
           and all other arguments of your training script)

所有的进程需要知道进程0的IP地址以及端口,这样所有进程可以在开始时同步,一般情况下称进程0是master进程,通常在进程0中打印信息或者保存模型。

参考链接

https://github.com/pytorch/examples/tree/main/imagenet

https://zhuanlan.zhihu.com/p/113694038

https://support.huaweicloud.com/intl/zh-cn/develop-modelarts/modelarts-distributed-0008.html

https://blog.csdn.net/ytusdc/article/details/122091284

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch分布式训练 的相关文章

随机推荐

  • tomcat中仅启动指定的项目,不启动其它项目

    使用tomcat调试项目时 若部署过多项目会导致启动时间很长 这时指定自己想启动的项目就显得很有必要了 具体方法如下 一 打开tomcat中的server xml配置文件 在Host节点中添加或修改如下属性 deployXML false
  • QT系列第3节 QT中混合UI设计

    QT开发过程中 经常使用Qt designer设计器和代码方式结合来及进行ui设计 本节将介绍这两种方式混合进行ui开发 目录 1 工程添加图片资源 2 添加菜单 3 添加工具栏 4 简单文本编辑器实现 5 QT Creator常用快捷键
  • scrapy设置代理ip(精简版)

    在middlewares py文件中 添加下面的代码 import scrapy from scrapy import signals import random class ProxyMiddleware object def proce
  • 爬取淘宝价格

    爬取淘宝价格 from selenium import webdriver from lxml import etree from time import sleep 实例化一个浏览器对象 bro webdriver Chrome exec
  • [调用函数]

    注 梳理 整理 用来帮助自己学习 如有错误 请指出 1 编写一个函数 该函数接受两个整数作为参数并返回它们的和 在主函数中调用该函数并输出结果 示例输入 5 7 示例输出 12 解题思路 首先需要定义一个函数来实现两个整数的加法 函数的返回
  • 对于进程同步和异步的理解

    多进程并发执行具有异步的特性 进程异步就是指一个以上的进程在并发执行时具有的异步特型 就比如说两个进程之间指令的执行顺序是不确定的 具有很强的随机性 举个例子 现在有两个并发执行的进程 A 和 B 各自都有n条指令需要执行 然而 我的CPU
  • python后端学习(七)HTTP协议、实现WEB服务器

    HTTP协议简介 浏览器 gt 服务器发送的请求格式如下 GET HTTP 1 1 请求方式 路径 协议及版本 Host 127 0 0 1 8080 请求的地址 Connection keep alive 长连接 Accept text
  • RS485模块的介绍及引脚连线说明

    RS485模块通讯 1 RS 485简介 2 SP3485芯片及应用 1 RS 485简介 RS 485采用平衡发送和差分接收 因此具有抑制共模干扰的能力 以下是某宝上RS485模块的截图 应用特点 传输数据速度快 高达10Mbps 即10
  • 【老生谈算法】matlab实现粒子滤波及实现

    粒子滤波及matlab实现 1 文档下载 本算法已经整理成文档如下 有需要的朋友可以点击进行下载 说明 文档 点击下载 本算法文档 老生谈算法 matlab实现粒子滤波及实现 doc 更多matlab算法原理及源码详解可点击下方文字直达 5
  • 《Apache MINA 2.0 用户指南》第六章:传输

    最近准备将Apache MINA 2 0 用户指南英文文档翻译给大家 但是我偶然一次百度 发现 Defonds 这位大牛已经翻译大部分文档 原文链接 http mina apache org mina project userguide c
  • 从单向链表中删除指定值的节点-牛客网

    题目描述 输入一个单向链表和一个节点的值 从单向链表中删除等于该值的节点 删除后如果链表中无节点则返回空指针 链表结点定义如下 struct ListNode int m nKey ListNode m pNext 详细描述 本题为考察链表
  • 电脑知识【自用】

    1 解决BIOS误删Windows Boot Manager 方法一 通过Grub进行修复 通过以下步骤解决 重启电脑 按F12进入BIOS SETUP 进入Boot Sequence 查看Windows Boot Manager是否丢失
  • Sftp实现文件的上传下载(com.jcraft.jsch依赖解决解决:Could not parse response code.Server Reply: SSH-2.0-OpenSSH_5.3)

    依赖如下
  • 给vcenter中的Esxi主机网络添加VLAN

    1 使用vSphere Client连接到VMware ESXi Server 在 配置 网络 中 可以看到 当前有两个虚拟交换机 并且为该虚拟交换机分配了管理地址10 10 228 81 点击 添加网络 如图所示 2 添加配置向导 在网络
  • 圆检测学习笔记

    目录 边缘检测 再检测圆 霍夫圆检测 转自 深度OpenCV开发之精准找圆 GitHub zikai1 CircleDetection circle detection inscribed triangles image processin
  • Hive中自定义UDF,UDTF实例以及三种自定义函数的区别

    Hive中有三种UDF 分类 1 用户定义函数 user defined function UDF 2 用户定义聚集函数 user defined aggregate function UDAF 3 用户定义表生成函数 user defin
  • Jpcap环境安装配置

    1 Jpcap 下载地址 链接地址不可用 问度娘JpcapSetup 0 7 exe 然后下载 2 WinPcap 下载地址 http www winpcap org install default htm 3 Libpcap 下载地址 h
  • vue+axios+element+ui实现手机发送验证码及校验验证码功能

    配合express拿到官网接口 1 首先布局 使用element表单输入框 标签上面相应的绑定了一些事件以及校验规则 下面逻辑代码中有注释 div class wrap div
  • 按钮控件之4---QToolButton 工具按钮控件

    一 设置和基本显示 QWidget w QToolButton pb1 new QToolButton w 设置文字 setText 设置图标 setIcon 改变图标大小 setIconSize 设置提示文本 setToolTip pb1
  • PyTorch分布式训练

    PyTorch切分模型和数据两种方法 DataParallel是单进程多线程的 只用于单机情况 DistributedDataParallel支持模型并行 同时适用于单机和多机情况 多进程 每个进程都有独立的优化器 执行自己的更新过程 梯度