手把手教你用MindSpore训练一个AI模型!

2023-11-15

首先我们要先了解深度学习的概念和AI计算框架的角色(https://zhuanlan.zhihu.com/p/463019160),本篇文章将演示怎么利用MindSpore来训练一个AI模型。和上一章的场景一致,我们要训练的模型是用来对手写数字图片进行分类的LeNet5模型

请参考(http://yann.lecun.com/exdb/lenet/)。

图1 MindSpore使用流程

安装MindSpore

MindSpore提供给用户使用的是Python接口(什么是Python,请参考:

https://zhuanlan.zhihu.com/p/462756985),所以我们首先需要安装MindSpore的whl包,安装之后就可以导入(import)MindSpore提供的方法接口了。安装whl包有两种方式:

方式一:进入MindSpore官网,根据自己的设备和Python版本选择安装命令。比如我的Python版本是3.7.5,我的设备是笔记本(CPU),那么我就复制下图红框中的命令进行安装:

图2 MindSpore安装界面

安装过程如下:

图3 MindSpore安装过程

注意:由于MindSpore还依赖于其他的Python三方库,所以在安装过程中,系统还会自动下载、安装其他的Python三方库,如numpy、pillow、scipy等等,安装结束后,如果能 import mindspore 成功,说明MindSpore安装成功了:

图4 MindSpore安装成功

方式二:可以在版本列表中找到对应的whl包,点击就能下载:

图5 MindSpore版本下载列表

下载完成后,把whl包放到自己的目录下,执行 pip install xxx.whl:

图6 MindSpore第二种安装方式

定义模型

安装好MindSpore之后,我们就可以导入MindSpore提供的算子(卷积、全连接、池化等函数:https://zhuanlan.zhihu.com/p/463019160)来构建我们的模型了。可以这么比喻:我们构建一个AI模型就像建一个房子,而MindSpore提供给我们的算子就像是砖块、窗户、地板等基本组件。

图7 定义LeNet5模型

如上图所示,我们用到的“砖块”都是mindspore.nn模块提供的。注意:这里用到了Python的类(class),由②和③两部分组成。我们这里定义的类是class LeNet5,它由初始化函数 __init__(self) 和构造函数construct(self, x)组成。初始化函数定义了我们构造模型所需要用到的算子,比如conv算子、relu算子、flatten算子等等,这些算子都是从mindspore.nn获取的;构造函数就是把我们在初始化函数中导入的算子按顺序排放,构成我们最终的模型。construct()函数的输入就是我们这个模型预测的对象,比如第一章讲的黑白图片像素矩阵;而“return y”中的就是预测的结果,对应于第一章讲到的10分类手写数字数据集,就是一个行10列的数组(这里的是指输入图片的数量,AI模型支持多张图片同时推理)。

导入训练数据集

什么是训练数据集?刚刚定义好的模型是不能对图片进行正确分类的,我们要通过“训练”过程来调整模型的参数矩阵的值。训练过程就需要用到训练样本,也就是打上了正确标签的图片。这就好比我们教小孩儿认识动物,需要拿几张图片给他们看,然后告诉他们这是什么、那是什么,教了几遍之后,小孩儿就能认识了。那么我们训练LeNet5模型就需要用到MNIST数据集,请参考(http://yann.lecun.com/exdb/mnist/)。这个数据集由两部分组成:训练集(6万张图片)和测试集(1万张图片),都是0~9的黑白手写数字图片。训练集是用来训练AI模型的,测试集是用来测试训练后的模型分类准确率的。

下载得到的数据集最初是压缩文件,还不能直接传给MindSpore的训练接口使用,我们要先用MindSpore提供的数据处理接口把他们读进来:

import mindspore.dataset as ds
mnist_ds = ds.MnistDataset(data_path)  # 导入下载的MNIST数据集

然后进行数据增强(比如把图片大小转化成相同的尺寸、像素值标准化、归一化等操作),提升训练效率:

import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype

# 定义数据增强函数
def create_dataset(data_path, batch_size=32):  # batch_size是每一步训练使用的图片数量,一般取32
    """
    create dataset for train or test

    Args:
        data_path (str): Data path
        batch_size (int): The number of data records in each group
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)  # 导入下载的MNIST数据集
    # define some parameters needed for data enhancement and rough justification
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # according to the parameters, generate the corresponding data enhancement method
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # using map to apply operations to a dataset
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label")
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image")
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image")
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image")
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image")

    # process the generated dataset
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    return mnist_ds

 训练模型

训练数据集和模型定义完成之后呢,我们就可以开始训练模型了。但是在训练之前,我们还需要从MindSpore导入两个函数:

  • 损失函数,也就是衡量预测结果和真实标签之间的差距的函数。看过上一章的同学可能会记得,我们之前用的损失函数是真实值与预测值之差的2-范数:

图8 2-范数损失

在这里,我们使用业界最常用的交叉熵损失函数SoftmaxCrossEntropyWithLogits,对于真实标签

和预测值,它们之间的交叉熵损失计算公式为:

其中J代表数组的下标,。从MindSpore导入损失函数:

from mindspore.nn import SoftmaxCrossEntropyWithLogits
# define the loss function
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') 
  • 优化器,优化器就是用来求解损失函数关于模型参数的更新梯度的,它是整个训练过程中最重要的工具!我们这里用MindSpore提供的Momentum优化器:

import mindspore.nn as nn

lr = 0.01  # 定义学习率
momentum = 0.9  # 定义Momentum优化器的超参
# define the optimizer
net_opt = nn.Momentum(network.trainable_params(), lr, momentum)  # 导入mindspore提供

 准备好损失函数和优化器之后我们就可以开始训练模型了,也非常简单,我们先把前面定义好的模型、损失函数、优化器封装成一个Model:

from mindspore import Model
net = LeNet5()
model = Model(net, net_loss , net_opt , metrics={'acc', 'loss'})

然后使用model.train接口就可以训练我们定义的LeNet5模型了:

loss_cb = LossMonitor(per_print_times=ds_train.get_dataset_size())  # 用于监控训练过程中损失函数值的变化
ds_train = create_dataset(train_data_dir)  # 传入下载的训练集的路径
model.train(num_epochs, ds_train, callbacks=[loss_cb])  # num_epochs是训练的轮数,往往训练多轮才能使模型收敛

测试训练后的模型准确率

训练结束后,调用model.eval()计算训练后的模型在测试集上面的分类准确率:

ds_eval = create_dataset(test_data_dir)  # 传入下载的训练集的路径
metrics = model.eval(ds_eval)

小结

祝贺你耐心看完了MindSpore训练模型的完整过程,如果你想动手操作一遍,但是又没有现成的环境,那么你可以使用官网提供的“在线运行”来体验一番:

图9 MindSpore官网提供的免费体验入口

这是体验过程的实操视频:

https://zhuanlan.zhihu.com/p/463229660

欢迎投稿

欢迎大家踊跃投稿,有想投稿技术干货、项目经验等分享的同学,可以添加MindSpore官方小助手:小猫子(mindspore0328)的微信,告诉猫哥哦!

昇思MindSpore官方交流QQ群 : 486831414群里有很多技术大咖助力答疑!

MindSpore官方资料

GitHub : https://github.com/mindspore-ai/mindspore

Gitee : https : //gitee.com/mindspore/mindspore

官方QQ群 : 486831 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

手把手教你用MindSpore训练一个AI模型! 的相关文章

  • 使用 Marshmallow 中的数据更新行 (SQLAlchemy)

    我正在使用 Flask Flask SQLAlchemy Flask Marshmallow marshmallow sqlalchemy 尝试实现 REST api PUT 方法 我还没有找到任何使用 SQLA 和 Marshmallow
  • 如何配置散景图以具有响应宽度和固定高度

    我使用通过组件功能嵌入的散景 实际上我使用 plot sizing mode scale width 它根据宽度进行缩放并保持纵横比 但我想要一个响应宽度但固定或最大高度 这怎么可能实现呢 有stretch both and scale b
  • 刷新访问令牌时出现“invalid_grant”错误的情况?

    最近我一直在为这个问题揪心 一些背景 使用oauth2客户端 https code google com p google api python client 库来管理用户的令牌 这些令牌用于定期并发执行各种后台任务 每次要为用户运行其中一
  • 使用 Flask SQLAlchemy 进行表(模型)继承

    我遵循了这个建议question https stackoverflow com questions 1337095 sqlalchemy inheritance但我仍然收到此错误 sqlalchemy exc NoForeignKeysE
  • 01 无效令牌[重复]

    这个问题在这里已经有答案了 嘿 学习 python3有一段时间了 遇到字典和dictionary name get 方法并尝试获取随机键值 问题 data data get key 1 它有效并且返回 1 但如果我使用data get ke
  • 用于打印 C/C++ 文件的所有函数定义的 Python 脚本

    我想要一个 python 脚本来打印 C C 文件中定义的所有函数的列表 e g abc c定义两个函数为 void func1 int func2 int i printf d i return 1 我只想搜索文件 abc c 并打印其中
  • Python:如何重构循环导入

    我有件事可以帮你做engine setState
  • 如何通过 Python socket.send() 发送字符串以外的任何内容

    我对 Python 编程非常陌生 但出于必要 我必须快速地将一些东西组合在一起 我正在尝试通过 UDP 发送一些数据 除了当我执行 socket send 时 我必须以字符串形式输入数据之外 一切都正常 这是我的程序 这样你就可以看到我在做
  • Python3模拟用另一个函数替换函数

    如何使用 python 中的另一个函数来模拟一个函数 该函数也将提供一个模拟对象 我有类似以下操作的代码 def foo arg1 arg2 r bar arg1 does interesting things 我想替换的实现bar函数 让
  • Docker:通过 Gunicorn 运行 Flask 应用程序 - Worker 超时?表现不佳?

    我正在尝试创建一个用Python Flask编写的新应用程序 由gunicorn运行 然后进行dockerized 我遇到的问题是 docker 容器内的性能非常差 不一致 我最终得到了响应 但我不明白为什么性能会下降 有时我会在日志中看到
  • Python/Flask:应用程序在关闭后正在运行

    我正在开发一个简单的 Flask Web 应用程序 我使用 Eclipse Pydev 当我开发该应用程序时 由于代码更改 我必须经常重新启动该应用程序 这就是问题所在 当我运行该应用程序时 我可以在本地主机上看到该框架 这很好 但是当我想
  • Python将csv数据导出到文件中

    我有以下运行良好的代码 但我无法修剪数据并将其存储在数据文件中 import nltk tweets love this car this view amazing not looking forward the concert def g
  • 指定 Parquet 属性 pyspark

    如何在 PySpark 中指定 Parquet 块大小和页面大小 我到处搜索 但找不到任何有关函数调用或导入库的文档 根据火花用户档案 https mail archives apache org mod mbox spark user 2
  • 将带有两层分隔符的字符串转换为字典 - python

    给定一个字符串 s x t1 ny t2 nz t3 我想转换成字典 sdic x 1 y 2 z 3 我通过这样做让它工作 sdic dict tuple j split t for j in i for i in s split n F
  • 散景中的时间序列流

    我想在散景中绘制实时时间序列 我只想在每次更新时绘制新的数据点 我怎样才能做到这一点 散景网站上有一个动画情节的示例 但它每次都需要重新绘制整个图片 另外 我正在寻找一个简单的示例 我可以在其中逐点绘制时间序列的实时绘图 散景效果0 11
  • Scrapy - 不会爬行

    我正在尝试运行递归爬行 由于我编写的爬行不能正常工作 因此我从网络上提取了一个示例并进行了尝试 我真的不知道问题出在哪里 但是爬行没有显示任何错误 谁能帮我这个 另外 是否有任何逐步调试工具可以帮助理解蜘蛛的爬行流程 非常感谢任何与此相关的
  • 通过套接字发送字符串(python)

    我有两个脚本 Server py 和 Client py 我心中有两个目标 能够从客户端一次又一次地向服务器发送数据 能够将数据从服务器发送到客户端 这是我的 Server py import socket serversocket soc
  • 从 subprocess.Popen 获取整个输出

    我通过调用 subprocess Popen 得到了一个有点奇怪的结果 我怀疑这与我对 Python 的陌生有很大关系 args cscript USERPROFILE tools jslint js USERPROFILE tools j
  • 在 Gensim 中通过 ID 检索文档的字符串版本

    我正在使用 Gensim 进行一些主题建模 并且已经达到使用 LSI 和 tf idf 模型进行相似性查询的程度 我取回 ID 集和相似点 例如 299501 0 64505910873413086 如何获取与 ID 在本例中为 29950
  • Elastic Beanstalk 上的 Django + MySQL - 查询 MySQL 时出错

    当我在 Elastic beanstalk 上托管的 Django 应用程序上查询 MySQL 时 出现错误 错误说 admin login 处出现操作错误 1045 用户 adminDB 172 30 23 5 的访问被拒绝 使用密码 Y

随机推荐

  • 【Bootstrap】常用组件(框架)

    Bootstrap常用组件 目录 1 网格系统 Grid System 网格系统的工作原理 不同设备的尺寸定义与其对应类名 基本的网格结构 偏移列 2 Bootstrap 表格 3 容器container类 4 Bootstrap 按钮 5
  • Burp Suite配置代理

    1 打开burp工具后按照下图的步骤 2 点开Add后如下弹窗 输入端口号和地址后点击ok即可
  • 正则表达式基础语法大全

    正则表达式基础语法 1 普通字符 字母 数字 汉子 下划线 以及没有特殊定义的标点符号 都是 普通字符 表达式中的普通字符 在匹配一个字符串的时候 匹配与之相同的一个字符 2 简单的转义字符 3 标准字符集合 能够与 多种字符 匹配的表达式
  • 查看oracle数据库防火墙设置,用三个方法设置Oracle数据库穿越防火墙

    用三个方法设置Oracle数据库穿越防火墙 方法一 在系统注册表中 hkey local machinesoftwareoraclehome0下加入字符串值 USE SHARED SOCKET TRUE 方法二 1 首先 我们需要将数据库实
  • x264中open_file_yuv函数欣赏(顺便谈谈如何利用指针在被调函数中改变主调函数中变量的值)

    先来看一个结构体yuv input t typedef struct FILE fh int width height int next frame yuv input t yuv input t结构体用fh这个文件指针打开原始的yuv文件
  • 超详细的VSCode下载和安装教程(非常详细)从零基础入门到精通,看完这一篇就够了。

    文章目录 1 引言 2 下载VSCode 3 解决VSCode下载速度特别慢 4 安装VSCode 1 引言 今天用WebStorm运行前端代码时 发现不太好打断点 于是 打算改用VSCode来运行前端代码 但前提是要安装VSCode 如下
  • c,c++小白到大神系列教程之一:C语言入门-王健伟-专题视频课程

    c c 小白到大神系列教程之一 C语言入门 1127人已学习 课程介绍 本课程针对 有一点计算机基础比如知道二进制 八进制 十六进制数据的含义 对内存 堆 栈等有基本概念的计算机初学者 全面介绍C语言精华内容以及利用C语言进行程序设计的方法
  • 三角脉冲信号的表达式_【信号处理工具箱】—信号表示方法

    1 工具箱中常见的函数 1 sawtooth函数 sawtooth函数用于产生锯齿波或三角波信号 格式如下 t 0 0 0001 1 y sawtooth 2 pi 50 t subplot 211 plot t y axis 0 0 2
  • 用java做一个超级马里奥的小游戏

    好的 首先你需要准备一些基本的知识和工具 了解 Java 语言的基本语法和编程概念 安装好 Java 开发环境 比如 Eclipse 或者 IntelliJ IDEA 准备好一些图像和音频资源 用于游戏中的背景 角色 音效等元素 接下来 你
  • wazuh 收集 suricata eve.json日志

    安装suricata和规则 源码或者安装包 本博客提供安装包操作方式 切换成超级用户进行操作 yum y install epel release wget jq curl O https copr fedorainfracloud org
  • 2013豆瓣校园招聘研发类笔试题

    2013豆瓣校园招聘研发类笔试题 1 将一个递归算法改为对应的非递归算法时 通常需要使用 A 优先队列 B 队列 C 循环队列 D 栈 2 爸爸 妈妈 妹妹 小强 至少两个人同一生肖的概率是多少 A 41 96 B 55 96 C 72 1
  • qqkey获取原理_通过call获取qqkey支持最新版

    如果真 进程 是否存在 TIM exe 假 且 进程 是否存在 QQ exe 假 str 你还没有登录QQ 返回 0 如果真结束 如果真 进程 是否存在 QQ exe pid 进程 取同名ID QQ exe pids 计次循环首 pid i
  • python Web开发 flask轻量级Web框架

    O flask介绍 Flask是一个使用 Python 编写的轻量级 Web 应用框架 其 WSGI 工具箱采用 Werkzeug 模板引擎则使用 Jinja2 Flask使用 BSD 授权 Flask也被称为 microframework
  • 数据结构题目-稀疏矩阵

    目录 问题 AU 函数可变参数练习 附加代码模式 问题 AV 多维下标向一维下标的换算 问题 AW 稀疏矩阵类型判断 问题 AX 稀疏矩阵转换成简记形式 附加代码模式 问题 AY 根据三元组输出稀疏矩阵 问题 AZ 三元组法表示的稀疏矩阵
  • python3.7解决ModuleNotFoundError: No module named '_bz2'

    安装完python3 7之后运行一个软件提示错误 from bz2 import BZ2Compressor BZ2Decompressor ModuleNotFoundError No module named bz2 解决方法如下 一
  • Linux内存管理子系统

    1 Linux子系统 Linux内核组成 SCI系统调用接口 PM进程管理子系统 MM内存管理子系统 Arch体系结构相关代码 DD驱动程序 Network Stack网络协议站 VFS虚拟文件系统 DD驱动程序 2 Linux内存管理子系
  • 《Apache MINA 2.0 用户指南》第十二章:日志过滤器

    后台 用户开放基于Apache MiNa的应用程序 用户可以在应用程序中创建日志管理 SLF4J MINa采用SLF4j作为日志输出 你可以在这里发现很多关于SLF4j的相关介绍 这个日志工具允许任何形式的日志系统实施 你可能使用 log4
  • 手把手教你使用gtest写单元测试

    开源框架 gtest 它主要用于写单元测试 检查真自己的程序是否符合预期行为 这不是QA 测试工程师 才学的 也是每个优秀后端开发codoer的必备技能 本期博文内容及使用的demo 参考 Googletest Basic Guide 1
  • 设计模式——多线程下的懒汉式单例

    懒汉 模式虽然有优点 但是每次调用 GetInstance 静态方法时 必须判断NULL m instance 使程序相对开销增大 多线程中会导致多个实例的产生 从而导致运行代码不正确以及内存的泄露 对于多线程的问题 我们可以看下面这个例子
  • 手把手教你用MindSpore训练一个AI模型!

    首先我们要先了解深度学习的概念和AI计算框架的角色 https zhuanlan zhihu com p 463019160 本篇文章将演示怎么利用MindSpore来训练一个AI模型 和上一章的场景一致 我们要训练的模型是用来对手写数字图