matlab2019a中LSTM网络使用方法及源码示例(Deep Learning Toolbox系列篇6)

2023-11-18

此示例说明如何使用长短期记忆 (LSTM) 网络对序列数据进行分类。

要训练深度神经网络以对序列数据进行分类,可以使用 LSTM 网络。LSTM 网络允许您将序列数据输入网络,并根据序列数据的各个时间步进行预测。

此示例使用 [1] 和 [2] 中所述的日语元音数据集。此示例训练一个 LSTM 网络,旨在根据表示连续说出的两个日语元音的时序数据来识别说话者。训练数据包含九个说话者的时序数据。每个序列有 12 个特征,且长度不同。该数据集包含 270 个训练观测值和 370 个测试观测值。

源码:


%% 通用matlab脚本三连
clear
clc
close all

%% 加载序列数据
% 加载日语元音训练数据。XTrain 是包含 270 个不同长度的 12 维序列的元胞数组。
% Y 是对应于九个说话者的标签 "1"、"2"、...、"9" 的分类向量。
% XTrain 中的条目是具有 12 行(每个特征一行)和不同列数(每个时间步一列)的矩阵。
[XTrain,YTrain] = japaneseVowelsTrainData;
XTrain(1:5)

%% 在绘图中可视化第一个时序。每行对应一个特征。
figure
plot(XTrain{1}')
xlabel("Time Step")
title("Training Observation 1")
legend("Feature " + string(1:12),'Location','northeastoutside')

%% 准备要填充的数据
numObservations = numel(XTrain);
for i=1:numObservations
    sequence = XTrain{i};
    sequenceLengths(i) = size(sequence,2);
end

%% 按序列长度对数据进行排序。
[sequenceLengths,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
YTrain = YTrain(idx);

%% 在条形图中查看排序的序列长度。
figure
bar(sequenceLengths)
ylim([0 30])
xlabel("Sequence")
ylabel("Length")
title("Sorted Data")

%% 选择小批量大小 27 以均匀划分训练数据,并减少小批量中的填充量。下图说明了添加到序列中的填充。
miniBatchSize = 27;

%% 定义 LSTM 网络架构
inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;

layers = [ ...
    sequenceInputLayer(inputSize)
    bilstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]

maxEpochs = 100;
miniBatchSize = 27;

%% 指定训练选项
options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'GradientThreshold',1, ...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize, ...
    'SequenceLength','longest', ...
    'Shuffle','never', ...
    'Verbose',0, ...
    'Plots','training-progress');

%% 训练 LSTM 网络
net = trainNetwork(XTrain,YTrain,layers,options);

%% 测试 LSTM 网络
[XTest,YTest] = japaneseVowelsTestData;
XTest(1:3)

%% LSTM 网络 net 已使用相似长度的小批量序列进行训练
numObservationsTest = numel(XTest);
for i=1:numObservationsTest
    sequence = XTest{i};
    sequenceLengthsTest(i) = size(sequence,2);
end
[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
YTest = YTest(idx);

%% 对测试数据进行分类
miniBatchSize = 27;
YPred = classify(net,XTest, ...
    'MiniBatchSize',miniBatchSize, ...
    'SequenceLength','longest');

%% 计算预测值的分类准确度。
acc = sum(YPred == YTest)./numel(YTest)

运行结果:

本文着重讲解一下matlab的深度学习训练方式

在matlab中,训练一个网络有一个常用的函数trainNetwork;此函数通常有四个参数

%% 训练 LSTM 网络
net = trainNetwork(XTrain,YTrain,layers,options);

 其中,XTrain为训练的输入数据集,YTrain为训练数据对应的标签;

layers为网络的架构,也就是指网络每层的处理模式。

options为超参数的设置,包括学习率,优化方法,迭代次数以及批量大小问题等。

1. 构建网络模型;

inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;

layers = [ ...
    sequenceInputLayer(inputSize) %输入层
    bilstmLayer(numHiddenUnits,'OutputMode','last') %第一层隐层,LSTM架构
    fullyConnectedLayer(numClasses) % 第二层隐层,全连接层
    softmaxLayer %softmax处理
    classificationLayer] %分类层

 如上代码片段所示,逐行对每一层的网络处理模式进行说明,并将此构成的数组形式赋予一个变量(此中为layers)。

2. 指定训练的超参数

%% 指定训练选项
options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'GradientThreshold',1, ...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize, ...
    'SequenceLength','longest', ...
    'Shuffle','never', ...
    'Verbose',0, ...
    'Plots','training-progress');

其中参数设定的格式必须由trainingOptions进行格式打包,即options的数据类型必须为TrainingOptions派,此示例中options的数据类型为TrainingOptionsADAM。另外本示例数据集方面的XTrain与XTest都是matlab中的cell数据类型。

 更详细的matlab深度学习训练方法(trainNetwork用法)请参考matlab官方文档或本博客Deep Learning ToolBox系列7。

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

matlab2019a中LSTM网络使用方法及源码示例(Deep Learning Toolbox系列篇6) 的相关文章

  • 读取 MEX 文件中的 4D 数组

    我在 MATLAB 中有一个 4 维数组 我正在尝试访问 MEX 函数中的数组 下面创建 testmatrix 一个 4 维矩阵 已知数据为uint8 type Create a 4D array 2x 2y rgb 3 framenumb
  • 我的傅立叶逆变换中的尖峰

    我正在尝试在 MATLAB 中比较两个数据集 为此 我需要通过傅里叶变换数据来过滤数据集 对其进行过滤 然后对其进行逆傅里叶变换 然而 当我对数据进行逆傅里叶变换时 我在红色数据集的两端都出现了一个尖峰 图片显示了第一个尖峰 它在开始时应该
  • 如何从 Matlab 在 vi​​rtualenv 中执行 Python 代码

    我正在创建一个用于研究的 Matlab 工具箱 我需要执行 Matlab 代码 但也需要执行 Python 代码 我想允许用户从 Matlab 执行 Python 代码 问题是 如果我立即执行此操作 我将必须在 Python 环境中安装所有
  • 从彩色背景中提取黑色对象

    人眼很容易辨别black来自其他颜色 但是计算机呢 我在普通的A4纸上打印了一些色块 由于组成彩色图像有青色 品红色和黄色三种墨水 所以我设置每个块的颜色C 20 C 30 C 40 C 50 以及其余两种颜色是 0 这是我的源图像的第一列
  • 我的 matlab 图中需要不同的颜色

    这是我的情节代码 问题是我的图中的两条线具有相同的颜色 我需要为图中的每条线 总共 4 条线 分配一个特殊的颜色 for i 1 nFolderContents data hdrload folderContents i if size f
  • 有没有办法在 MATLAB 中执行函数内联?

    我可以使用什么语言功能或开箱即用的技巧来完成 MATLAB 中的函数内联 令人烦恼的是 Google 搜索 matlab 内联函数 http www google com search q matlab inline function揭示了
  • 如何打开 matlab p 代码文件

    有谁知道如何查看 matlab p 代码文件的代码 p 代码文件专门存在 以便您可以共享代码 以便其他人无法查看它 换句话说 您看不到 Matlab p 代码文件的代码
  • 带 if 语句的可向量化 FIND 函数 MATLAB

    我有一个矩阵u 我想遍历所有行和所有列并执行以下操作 如果元素非零 我返回行索引的值 如果元素为零 则查找该元素之后的下一个非零元素的行索引 我可以使用两个带有 find 函数的 for 循环轻松完成此操作 但我需要多次执行此操作 不是因为
  • MATLAB 链表

    有哪些可能的方法来实现链表MATLAB http en wikipedia org wiki MATLAB 注意 我问这个问题是为了教学价值 而不是实用价值 我意识到 如果您实际上在 MATLAB 中滚动自己的链表 那么您可能做错了什么 然
  • 如何检测图像中对象的实例?

    我有一张包含几个特定对象的图像 我想检测这些物体在该图像中的位置 为此 我有一些模型图像 其中包含我想要检测的对象 这些图像在我想要检测的对象实例周围得到了很好的裁剪 这是一个例子 在这张大图里 我想检测此模型图像中表示的对象 自从你最初发
  • 将 Android 应用程序与服务器上的 Matlab 应用程序连接

    我正在 Android 上开发一个应用程序 它将获取图像输入 并将该输入传递到安装 MATLAB 应用程序的服务器 MATLAB 应用程序将计算结果并将其返回到该 Android 应用程序 我想知道我可以使用哪个服务器 如何将 MATLAB
  • 在matlab中设置图例符号的精度

    我有这个 leg2 strcat Max Degree num2str adet 1 1 ch l leg3 strcat Min Degree num2str adet 1 2 ch l leg4 strcat Max Request n
  • 如何使用Matlab提高PSD的分辨率

    我有音频信号 我用 Matlab 读取该信号 并使用 pwelch 获取其 PSD 这是我正在使用的代码 x Fs audioread audioFile wav x x 1 mono xPSD f pwelch x hamming 512
  • 使用 python 在网络上部署 matlab 应用程序

    您好 我想使用 python 在网络上部署 matlab 应用程序 有没有办法做到这一点 我已按照数学工作网站上的文档将我的应用程序转换为 jar 文件 java 类 有人能指出我前进的正确方向吗 事实上 您的 Matlab 代码打包为 J
  • Microsoft Visual C++ 2008 和 R2007b 的 Mex 类型

    我想对 vs2008 和 matlab2007b 使用 mex 类型 我尝试了下面的代码 include
  • 查找数组中元素之间的平均差异的有效方法

    希望标题不会让人困惑 通过例子来展示很简单 我有一个像这样的行向量 1 5 6 我想找到每个元素之间的平均差异 此示例中的差异为 4 和 1 因此平均值为 2 5 这是一个小例子 我的行向量可能非常大 我是 MatLab 新手 那么有没有一
  • 如何读取 10 位原始图像?其中包含 RGB-IR 数据

    我想知道如何从我的 10 位原始 它有 rgb ir 图像数据 数据中提取 RGB 图像 如何使用 Python 或 MATLAB 进行阅读 拍摄时的相机分辨率为 1280x720 室内照片图片下载 https drive google c
  • 将 3d 矩阵重塑为 2d 矩阵

    我有一个 3d 矩阵 n by m by t 在 MATLAB 中表示n by m一段时间内网格中的测量值 我想要一个二维矩阵 其中空间信息消失了 只有n m随着时间的推移测量t剩下 即 n m by t 我怎样才能做到这一点 你需要命令r
  • 是否有一个函数可以检查矩阵是否对角占优(行占优)

    矩阵是对角占优 http en wikipedia org wiki Diagonally dominant matrix 按行 如果对角线处的值在绝对意义上大于该行中所有其他绝对值的总和 对于列也是如此 只是相反 matlab中有没有函数
  • 通过傅里叶空间填充进行插值

    我最近尝试在 matlab 上实现一个在傅立叶域中使用零填充的插值方法的简单示例 但我无法正常工作 我总是有一个小的频移 在傅里叶空间中几乎不可见 但它在时空上产生了巨大的误差 由于傅里叶空间中的零填充似乎是一种常见 且快速 的插值方法 因

随机推荐