华为开源自研AI框架昇思MindSpore应用案例:分布式并行训练基础样例(CPU)

2023-05-16

目录

  • 一、环境准备
    • 1.进入ModelArts官网
    • 2.使用ModelArts体验实例
  • 二、准备环节
    • 1.下载数据集
    • 2.配置分布式环境
  • 三、加载数据集
  • 四、定义模型
  • 五、启动训练

本教程主要讲解,如何在CPU平台上,使用MindSpore进行数据并行分布式训练,以提高训练效率。
完整的样例代码:distributed_training_cpu

目录结构如下:

bash └─sample_code
    ├─distributed_training_cpu
    │      resnet.py
    │      resnet50_distributed_training.py
    │      run.sh 

其中,resnet.py和resnet50_distributed_training.py是训练网络定义脚本,run.sh是分布式训练执行脚本。

如果你对MindSpore感兴趣,可以关注昇思MindSpore社区

在这里插入图片描述

在这里插入图片描述

一、环境准备

1.进入ModelArts官网

云平台帮助用户快速创建和部署模型,管理全周期AI工作流,选择下面的云平台以开始使用昇思MindSpore,获取安装命令,安装MindSpore2.0.0-alpha版本,可以在昇思教程中进入ModelArts官网

在这里插入图片描述

选择下方CodeLab立即体验

在这里插入图片描述

等待环境搭建完成

在这里插入图片描述

2.使用ModelArts体验实例

进入昇思MindSpore官网,点击上方的安装

在这里插入图片描述

获取安装命令

在这里插入图片描述
在ModelArts中打开一个Terminal,输入安装命令

conda install mindspore=2.0.0a0 -c mindspore -c conda-forge

在这里插入图片描述

再点击侧边栏中的Clone a Repository,输入

https://gitee.com/mindspore/docs.git

在这里插入图片描述

可以看到docs项目导入成功

在这里插入图片描述

二、准备环节

1.下载数据集

本样例采用CIFAR-10数据集,由10类32*32的彩色图片组成,每类包含6000张图片,其中训练集共50000张图片,测试集共10000张图片。

将下载的数据集上传到,解压后文件夹为cifar-10-batches-bin

在这里插入图片描述

tar -zxvf cifar-10-binary.tar.gz 

在这里插入图片描述

在这里插入图片描述

注意:如果你使用的是ModelArts,接下来介绍内容都无需配置,可以跳到 五、启动训练,直接开始训练模型

2.配置分布式环境

CPU上数据并行主要分为单机多节点和多机多节点两种并行方式(一个训练进程可以理解为一个节点)。在运行训练脚本前,需要搭建组网环境,主要是环境变量配置和训练脚本里初始化接口的调用。

环境变量配置如下:


export MS_WORKER_NUM=8                # Worker number
export MS_SCHED_HOST=127.0.0.1        # Scheduler IP address
export MS_SCHED_PORT=6667             # Scheduler port
export MS_ROLE=MS_WORKER              # The role of this node: MS_SCHED represents the scheduler, MS_WORKER represents the worker

其中,

  • MS_WORKER_NUM:表示worker节点数,多机场景下,worker节点数是每机worker节点之和。
  • MS_SCHED_HOST:表示scheduler节点ip地址。
  • MS_SCHED_PORT:表示scheduler节点服务端口,用于接收worker节点发送来的ip和服务端口,然后将收集到的所有worker节点ip和端口下发给每个worker。
  • MS_ROLE:表示节点类型,分为worker(MS_WORKER)和scheduler(MS_SCHED)两种。不管是单机多节点还是多机多节点,都需要配置一个scheduler节点用于组网。

训练脚本里初始化接口调用如下:


import mindspore as ms
from mindspore.communication import init

ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU")
ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
ms.set_ps_context(enable_ssl=False)
init()

其中,

  • ms.set_context(mode=context.GRAPH_MODE, device_target=“CPU”):指定模式为图模式(CPU上PyNative模式下不支持并行),设备为CPU。
  • ms.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
    gradients_mean=True):指定数据并行模式,gradients_mean=True表示梯度归约后会进行一个求平均,当前CPU上梯度归约仅支持求和。
  • ms.set_ps_context:配置安全加密通信,可通过ms.set_ps_context(enable_ssl=True)开启安全加密通信,默认为False,关闭安全加密通信。
  • init:节点初始化,初始化完成表示组网成功。

三、加载数据集

分布式训练时,数据集是以数据并行的方式导入的。下面我们以CIFAR-10数据集为例,介绍以数据并行方式导入CIFAR-10数据集的方法,data_path是指数据集的路径,即cifar-10-batches-bin文件夹的路径。

样例代码如下


import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
import mindspore.dataset.transforms as transforms
from mindspore.communication import get_rank, get_group_size

def create_dataset(data_path, repeat_num=1, batch_size=32):
    """Create training dataset"""
    resize_height = 224
    resize_width = 224
    rescale = 1.0 / 255.0
    shift = 0.0

    # get rank_id and rank_size
    rank_size = get_group_size()
    rank_id = get_rank()
    data_set = ds.Cifar10Dataset(data_path, num_shards=rank_size, shard_id=rank_id)

    # define map operations
    random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4))
    random_horizontal_op = vision.RandomHorizontalFlip()
    resize_op = vision.Resize((resize_height, resize_width))
    rescale_op = vision.Rescale(rescale, shift)
    normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023))
    changeswap_op = vision.HWC2CHW()
    type_cast_op = transforms.TypeCast(ms.int32)

    c_trans = [random_crop_op, random_horizontal_op]
    c_trans += [resize_op, rescale_op, normalize_op, changeswap_op]

    # apply map operations on images
    data_set = data_set.map(operations=type_cast_op, input_columns="label")
    data_set = data_set.map(operations=c_trans, input_columns="image")

    # apply shuffle operations
    data_set = data_set.shuffle(buffer_size=10)

    # apply batch operations
    data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)

    # apply repeat operations
    data_set = data_set.repeat(repeat_num)

    return data_set

与单机不同的是,在构造Cifar10Dataset时需要传入num_shards和shard_id参数,分别对应worker节点数和逻辑序号,可通过框架接口获取,如下:

  • get_group_size:获取集群中worker节点数。
  • get_rank:获取当前worker节点在集群中的逻辑序号。

数据并行模式加载数据集时,建议对每卡指定相同的数据集文件,若是各卡加载的数据集不同,可能会影响计算精度。

四、定义模型

数据并行模式下,网络定义与单机写法一致,可参考ResNet网络样例脚本。

优化器、损失函数及训练模型定义可参考训练模型定义。

完整训练脚本代码参考样例,下面列出训练启动代码。


import os
import mindspore as ms
import mindspore.nn as nn
from mindspore import train
from mindspore.communication import init
from resnet import resnet50

def train_resnet50_with_cifar10(epoch_size=10):
    """Start the training"""
    loss_cb = train.LossMonitor()
    data_path = os.getenv('DATA_PATH')
    dataset = create_dataset(data_path)
    batch_size = 32
    num_classes = 10
    net = resnet50(batch_size, num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    model = ms.Model(net, loss_fn=loss, optimizer=opt)
    model.train(epoch_size, dataset, callbacks=[loss_cb], dataset_sink_mode=True)


if __name__ == "__main__":
    ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU")
    ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
    ms.set_ps_context(enable_ssl=False)
    init()
    train_resnet50_with_cifar10()

脚本里create_dataset和SoftmaxCrossEntropyExpand接口引用自distributed_training_cpu,
resnet50接口引用自ResNet网络样例脚本。

五、启动训练

在CPU平台上,以单机8节点为例,执行分布式训练。

进入到 /home/ma-user/work/docs/docs/sample_code/distributed_training_cpu目录下

在这里插入图片描述

在这里插入图片描述

通过以下shell脚本启动训练,指令bash run.sh /dataset/cifar-10-batches-bin,可以看到已经训练成功了

在这里插入图片描述

(PyTorch-1.8) [ma-user distributed_training_cpu]$bash run.sh cifar-10-batches-bin
==============================================================================================================
Please run the script with dataset path, such as: 
bash run.sh DATA_PATH
For example: bash run.sh /path/dataset
It is better to use the absolute path.
==============================================================================================================
scheduler start success!
worker 0 start success with pid 8240
worker 1 start success with pid 8241
worker 2 start success with pid 8242
worker 3 start success with pid 8243
worker 4 start success with pid 8244
worker 5 start success with pid 8245
worker 6 start success with pid 8246
worker 7 start success with pid 8247

#!/bin/bash
# run data parallel training on CPU

echo "=============================================================================================================="
echo "Please run the script with dataset path, such as: "
echo "bash run.sh DATA_PATH"
echo "For example: bash run.sh /path/dataset"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
set -e
DATA_PATH=$1
export DATA_PATH=${DATA_PATH}

export MS_WORKER_NUM=8
export MS_SCHED_HOST=127.0.0.1
export MS_SCHED_PORT=8117

# Launch 1 scheduler.
export MS_ROLE=MS_SCHED
python3 resnet50_distributed_training.py >scheduler.txt 2>&1 &
echo "scheduler start success!"

# Launch 8 workers.
export MS_ROLE=MS_WORKER
for((i=0;i<${MS_WORKER_NUM};i++));
do
    python3 resnet50_distributed_training.py >worker_$i.txt 2>&1 &
    echo "worker ${i} start success with pid ${!}"
done

其中,resnet50_distributed_training.py为定义的训练脚本。

对于多机多节点场景,需要在每个机器上按照这种方式,启动相应的worker节点参与训练,但scheduler节点只有一个,只需要在其中一个机器上(即MS_SCHED_HOST)启动即可。

定义的MS_WORKER_NUM值,表示需要启动相应数量的worker节点参与训练,否则组网不成功。

虽然针对scheduler节点也启动了训练脚本,但scheduler主要用于组网,并不会参与训练。

训练一段时间,打开worker_0日志,训练信息如下:

在这里插入图片描述

(PyTorch-1.8) [ma-user distributed_training_cpu]$tail -f worker_0.txt 

……
epoch: 1 step: 1, loss is 1.4686084
epoch: 1 step: 2, loss is 1.3278534
epoch: 1 step: 3, loss is 1.4246798
epoch: 1 step: 4, loss is 1.4920032
epoch: 1 step: 5, loss is 1.4324203
epoch: 1 step: 6, loss is 1.432581
epoch: 1 step: 7, loss is 1.319618
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

华为开源自研AI框架昇思MindSpore应用案例:分布式并行训练基础样例(CPU) 的相关文章

  • Ubuntu虚拟机可以上网,可以ping网络,但是无法update和install,显示不能连接或者无网络

    此方法为我找遍了网上全部解决方案之后还没有解决掉 xff0c 自己琢磨出来的其中一种方法 错误情况 xff1a 可以上浏览器看视频 xff0c 但是不能apt install vim或者gcc 解决方案 1 打开文件夹 2 输入 或者进入
  • 13:SpringBoot跨域解决方案-Java Spring

    目录 13 1 CorsFilter13 2 64 CrossOrigin13 3 WebMvcConfigurer 13 1 CorsFilter SpringBoot设置CORS的的本质都是通过设置响应头信息来告诉前端该请求是否支持跨域
  • 14:Servlet并发机制-Java Spring

    目录 14 1 并发14 2 Servlet并发机制14 3 Tomcat并发特点14 4 Tomcat线程模型 14 1 并发 并发 xff08 Concurrent xff09 是指多个任务交替执行的现象 xff0c 把CPU运行时间划
  • 手写字体识别实验-Python课程设计

    安装python 打开手写识别文件夹中的安装包文件夹 xff0c 双击python3 7 1可执行文件 xff0c 进行安装 弹出窗口 第一步 xff0c 勾选第二个复选框 Add Python 3 7 to PATH xff0c 然后点击
  • 生产企业原材料订购与运输的研究-数据处理课程设计

    目录 摘要1 引言2 规划问题说明3 问题重述3 1 问题分析3 2 数据说明3 3 模型假设3 4 符号说明 4 实验及分析4 1 问题一模型的建立与求解4 2 问题二模型的建立与求解 5 总结5 1 模型的优点5 2 模型的缺点 参考文
  • 信号发生器-电路与电子技术课程设计

    目录 1 设计任务与要求1 1 设计任务1 2 设计要求 2 方案设计与论证2 1 方案设计2 2 论证 3 信号发生器设计与计算3 1 信号发生器设计3 2 方波振荡电路图3 3 三角波振荡电路图3 4 参数计算 4 总原理图及元器件清单
  • 增益可控放大电路-电路与电子技术课程设计

    目录 1 设计任务与要求1 1 设计任务1 2 设计要求 2 方案设计与论证2 1 方案设计2 2 论证 3 放大电路设计与计算3 1 放大电路设计3 2 电子开关切换电路设计3 3 六档控制电路3 4 参数计算 4 总原理图及元器件清单4
  • 超声波测距实验-传感器原理及应用实验

    目录 一 实验实训主要内容二 实验实训方法 过程步骤三 实验实训结果与分析四 讨论小结 一 实验实训主要内容 学习超声波测距传感器的使用方法 xff0c 了解超声波测距传感器的原理和电路及实际应用 xff0c 了解超声波测距传感器的基本操作
  • 光敏传感器实验-传感器原理及应用实验

    目录 一 实验实训主要内容二 实验实训方法 过程步骤三 实验实训结果与分析四 讨论小结 一 实验实训主要内容 学习光敏传感器的使用方法 xff0c 了解光敏传感器的基本实验原理和实际应用 xff0c 熟练掌握光敏传感器实验的操作步骤 xff
  • 红外反射传感器实验-传感器原理及应用实验

    目录 一 实验实训主要内容二 实验实训方法 过程步骤三 实验实训结果与分析四 讨论小结 一 实验实训主要内容 学习红外反射传感器的使用方式 xff0c 了解红外反射传感器的实验原理和实际应用 xff0c 学习并理解Modbus数据格式所代表
  • 酒精传感器实验-传感器原理及应用实验

    目录 一 实验实训主要内容二 实验实训方法 过程步骤三 实验实训结果与分析四 讨论小结 一 实验实训主要内容 学习酒精传感器MQ 3的使用方法 xff0c 了解酒精传感器的实验原理和实际应用 xff0c 了解酒精传感器的基本操作模式 xff
  • hdoj 1575 Tr A (矩阵快速幂)

    Tr A Time Limit 1000 1000 MS Java Others Memory Limit 32768 32768 K Java Others Total Submission s 4549 Accepted Submiss
  • MapReduce排序过程

    排序是MapReduce框架中最重要的操作之一 MapTask和ReduceTask均会对数据按照key 进行排序 该操作属于Hadoop 的默认行为 xff0c 任何应用程序中的数据均会被排序 xff0c 而不管逻辑上是否需要 默认排序是
  • 温湿度传感器实验-传感器原理及应用实验

    目录 一 实验实训主要内容二 实验实训方法 过程步骤三 实验实训结果与分析四 讨论小结 一 实验实训主要内容 学习温湿度传感器的使用方法 xff0c 了解温湿度传感器的基本实验原理和实际应用 xff0c 熟练掌握温湿度传感器的基本步骤 xf
  • 烟雾检测传感器实验-传感器原理及应用实验

    目录 一 实验实训主要内容二 实验实训方法 过程步骤三 实验实训结果与分析四 讨论小结 一 实验实训主要内容 学习烟雾检测传感器的原理及检测方式 xff0c 了解烟雾检测传感器的实验原理和技术指标 xff0c 熟练掌握烟雾检测传感器的工作步
  • 4:Servlet-Java Web

    目录 4 1 Servlet简介4 2 HTTP协议4 3 Servlet与JSP4 4 Servlet处理的基本流程4 5 Servlet 容器4 6 Servlet程序实现 4 1 Servlet简介 Servlet是用Java语言编写
  • 5:Servlet程序-Java Web

    目录 5 1 Servlet要求5 2 创建Servlet5 3 第一个Servlet5 4 Servlet编译5 5 Servlet配置 5 1 Servlet要求 如果要开发一个可以处理HTTP请求的Servlet程序 xff0c 首先
  • 6:部署Servlet-Java Web

    目录 6 1 部署Servlet6 2 请求Servlet6 3 找不到servlet包6 4 Servlet映射的细节 6 1 部署Servlet 部署就是把Servlet的字节码文件放在适当的地方 为了在浏览器上访问Servlet xf
  • 7:Servlet表单-Java Web

    目录 7 1 Servlet响应7 2 Servlet获取客户端参数7 3 Servlet接受表单数据 7 1 Servlet响应 通过response对象对用户进行响应 创建输出流对象 PrintWriter out 61 respons
  • 8:Servlet生命周期-Java Web

    目录 8 1 Servlet生命周期8 2 Servlet生命周期对应的方法8 3 Servlet的多线程机制 8 1 Servlet生命周期 Servlet程序是运行在服务器端的一段Java程序 xff0c 其生命周期将受到Web容器的控

随机推荐