多维时序

2023-10-27

多维时序 | MATLAB实现Attention-LSTM(注意力机制长短期记忆神经网络)多输入单输出

基本介绍

本次运行测试环境MATLAB2020b;
文章针对LSTM 存在的局限性,提出了将Attention机制结合LSTM 神经网络的预测模型。采用多输入单输出回归预测,再将attention 机制与LSTM 结合作为预测模型,使预测模型增强了对关键时间序列的注意力。

模型背景

  • 由于LSTM 神经网络具有保存历史信息的功能,在处理长时间序列输入时相较于传统神经网络更为有效,于近几年取得了广泛的应用。长短时记忆神经网络最早是由Hochreite 和Schmidhuber 提出。
  • LSTM 神经网络克服了传统循环神经网络中存在的难以解决长期依赖性以及梯度消失和爆炸的问题。
  • 然而,采用传统的编码-解码器的LSTM模型在对输入序列学习时,模型会先将所有的输入序列编码成一个固定长度的向量,而解码过程则受限于该向量的表示,这也限制了LSTM 模型的性能。
  • 文章针对LSTM 存在的局限性,提出了将Attention机制结合LSTM 神经网络的预测模型,将attention 机制与LSTM 结合作为预测模型,使预测模型增强了对关键时间序列的注意力。

LSTM模型

  • 传统循环神经网络( RNN) 对短时间序列输入比较敏感,处理短时间序列的表现较好。但是当存在长期输入时,传统RNN 会出现在某一时刻之前的所有隐藏层状态在训练中都不会影响到权重数组W 的更新的情况。这就是所谓的梯度消失问题。
  • 由于传统RNN 神经网络在实际的应用中,存在着梯度爆炸或梯度消失的问题,因此传统RNN 不适合于解决长序列问题。为了解决RNN 存在的缺陷而出现的LSTM 神经网络近年来在语音识别、语言翻译以及图像处理等方面取得了广泛的应用,相较于传统循环神经网络( RNN) 有着其特别的优势。长短时记忆网络在原有RNN 的基础上,在隐藏层中额外加入了一个可以保存长期状态的单元C。LSTM单元内部结构如图1 所示。
    1
  • 与前馈神经网络类似,LSTM 网络的训练同样采用的是误差的反向传播算法( Back-propagation) ,因为LSTM 处理的是序列数据,所以在使用误差反向传播算法的时候需要将整个时间序列上的误差传播回来。当前LSTM 单元的状态会受到前一时刻LSTM 单元状态的影响。
  • 同时在误差反向传播计算时隐含层ht的误差不仅仅包含当前时刻t 的误差,也包括t 时刻之后所有时刻的误差,这就是误差基于时间反向传播算法的含义。

Attention-LSTM 模型

  • 传统的编码- 解码器( Encoder-Decoder) 模型在处理输入序列时,编码器Encoder 将输入序列Xt编码成固定长度的隐向量h,对隐向量赋予相同的权重。
  • 解码器Decoder 基于隐向量h 解码输出。当输入序列的长度增加时,分量的权重相同,模型对于输入序列Xt没有区分度,造成模型性能下降。
  • Attention 机制解决了此问题,Attention 是一种用于提升编码-解码模型效果的机制,其本质是模仿人在观察东西时大脑的思维活动。当某个场景经常在其中一部分出现重要的东西时,人脑就会进行学习,之后看到类似场景时注意力就会集中到该部分上。
  • 使模型对输入序列的不同时刻隐向量h 赋予了相对应的权重,按重要程度将隐向量合并为新的隐向量并输入到解码器Decoder。加入Attention 机制的Encoder-Decoder 模型如图2 所示。
    2
  • 标准的LSTM 采用的是传统编码- 解码器结构。输入LSTM 的数据序列无论长短都被编码成固定长度的向量表示。虽然LSTM 的记忆功能可以保存长期状态,但是在实际应用过程中,面对庞大的多维度,多变量数据集时不能很好地加以处理,在训练时模型可能会忽略某些重要的时序信息,造成模型的性能变差,影响预测精度。
  • 针对LSTM 自身存在的缺陷,文章在LSTM 的基础上引入了Attention 机制,目的是为了打破传统编码-解码器在编码过程中使用固定长度向量的限制,保留LSTM 编码器的中间状态,通过训练模型来对这些中间
    状态进行选择性学习。
  • 结合了Attention 机制的LSTM功率预测模型能够判断各输入时刻信息的重要程度,模型的训练效率得以提高。Attention 机制通过对LSTM 的输入特征赋予了不同的权重,突出了关键的影响因素,帮助LSTM 做出准确的判断,而且不会增加模型的计算和存储开销。
    3

程序设计

%% Attention_LSTM
% 数据集,列为特征,行为样本数目
clc
clear
close all
% 导入数据
load('./data.mat')
data(1,:) =[];
% 训练集
y = data.demand(1:1000);
x = data{1:1000,3:end};
[xnorm,xopt] = mapminmax(x',0,1);
[ynorm,yopt] = mapminmax(y',0,1);
x = x';
xnorm = xnorm(:,1:1000);
ynorm = ynorm(1:1000);
% 滞后长度
k = 24;           
% 转换成2-D image
for i = 1:length(ynorm)-k
    Train_xNorm(:,i,:) = xnorm(:,i:i+k-1);
    Train_yNorm(i) = ynorm(i+k-1);
end
Train_yNorm= Train_yNorm';
% 测试集
ytest = data.demand(1001:1170);
xtest = data{1001:1170,3:end};
[xtestnorm] = mapminmax('apply', xtest',xopt);
[ytestnorm] = mapminmax('apply',ytest',yopt);
xtest = xtest';
for i = 1:length(ytestnorm)-k
    Test_xNorm(:,i,:) = xtestnorm(:,i:i+k-1);
    Test_yNorm(i) = ytestnorm(i+k-1);
    Test_y(i) = ytest(i+k-1);
end
Test_yNorm = Test_yNorm';

clear k i x y
% 自定义训练循环的深度学习数组
Train_xNorm = dlarray(Train_xNorm,'CBT');
Train_yNorm = dlarray(Train_yNorm,'BC');
Test_xNorm = dlarray(Test_xNorm,'CBT');
Test_yNorm = dlarray(Test_yNorm,'BC');
% 训练集和验证集划分
TrainSampleLength = length(Train_yNorm);
validatasize = floor(TrainSampleLength * 0.1);
Validata_xNorm = Train_xNorm(:,end - validatasize:end,:);
Validata_yNorm = Train_yNorm(:,end-validatasize:end,:);
Train_xNorm = Train_xNorm(:,1:end-validatasize,:);
Train_yNorm = Train_yNorm(:,1:end-validatasize,:);
%% 参数设定
%数据输入x的特征维度
inputSize = size(Train_xNorm,1); 
%数据输出y的维度 
outputSize = 1;  
numhidden_units=50;
% 导入初始化参数
[params,~] = paramsInit(numhidden_units,inputSize,outputSize);    
[~,validatastate] = paramsInit(numhidden_units,inputSize,outputSize);
[~,TestState] = paramsInit(numhidden_units,inputSize,outputSize);     
% 训练相关参数
TrainOptions;
numIterationsPerEpoch = floor((TrainSampleLength-validatasize)/minibatchsize);
LearnRate = 0.01;
%% 迭代更新
figure
start = tic;
lineLossTrain = animatedline('color','b');
validationLoss = animatedline('color','r','Marker','o');
xlabel("Iteration")
ylabel("Loss")
% epoch 更新 
iteration = 0;
for epoch = 1 : numEpochs   
    [~,state] = paramsInit(numhidden_units,inputSize,outputSize);       % 每轮epoch,state初始化
    disp(['Epoch: ', int2str(epoch)])   
    % batch 更新
    for i = 1 : numIterationsPerEpoch      
        iteration = iteration + 1;
        disp(['Iteration: ', int2str(iteration)])
        idx = (i-1)*minibatchsize+1:i*minibatchsize;
        dlX = gpuArray(Train_xNorm(:,idx,:));
        dlY = gpuArray(Train_yNorm(idx));
        [gradients,loss,state] = dlfeval(@ModelD,dlX,dlY,params,state);       
        % L2正则化
        L2regulationFactor = 0.001;        
        [params,averageGrad,averageSqGrad] = adamupdate(params,gradients,averageGrad,averageSqGrad,iteration,LearnRate);               
        % 验证集测试
        if iteration == 1 || mod(iteration,validationFrequency) == 0
            output_Ynorm = ModelPredict(gpuArray(Validata_xNorm),params,validatastate);
            lossValidation = mse(output_Ynorm, gpuArray(Validata_yNorm));
        end       
        % 作图(训练过程损失图)
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
        if iteration == 1 || mod(iteration,validationFrequency) == 0
            addpoints(validationLoss,iteration,double(gather(extractdata(lossValidation))))
        end
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        legend('训练集','验证集')
        drawnow       
    end    
    % 每轮epoch 更新学习率
    if mod(epoch,10) == 0
        LearnRate = LearnRate * LearnRateDropFactor;
    end
end
  • 预测效果:
    4
    5

参考资料

[1] https://blog.csdn.net/kjm13182345320/article/details/120406657?spm=1001.2014.3001.5501
[2] https://blog.csdn.net/kjm13182345320/article/details/120377303?spm=1001.2014.3001.5501

致谢

  • 大家的支持是我写作的动力!
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

多维时序 的相关文章

  • 我的2016--"狗血"

    偶然看到了CSDN的 我的2016 主题征文活动 突然感慨一番 今年又快结束了 而我这一年的经历 可以浓缩为两个字 狗血 然而 我能用上如此不羁的词汇 并未能掩盖我木讷的内心 这才真的是狗血 感觉像在梦游 走了好远的路 一睁开眼睛却还在原地

随机推荐

  • 【100%通过率 】【华为OD机试 c++/java/python】任务总执行时长【 2023 Q1

    华为OD机试 题目列表 2023Q1 点这里 2023华为OD机试 刷题指南 点这里 题目描述 任务总执行时长 任务编排服务负责对任务进行组合调度 参与编排的任务有两种类型 其中一种执行时长为taskA 另一种执行时长为taskB 任务一旦
  • 2023华为产品测评官-开发者之声

    2023华为产品测评官 开发者之声 活动激发了众多开发者和技术爱好者的热情 他们纷纷递交了精心编写的产品测评报告 活动社群充满活力 参与者们热衷于交流讨论 互相帮助解决问题 一起探索云技术的无限可能 在此次活动中 华为云CodeArts获得
  • 怎么用计算机算ess tss,"ESS、RSS、TSS"分别表示什么?

    回归平方和 ESS 残差平方和 RSS 总体平方和 TSS 1 回归平方和 是反映自变量与因变量之间的相关程度的偏差平方和 用回归方程或回归线来描述变量之间的统计关系时 实验值yi与按回归线预测的值Yi并不一定完全一致 2 残差平方和是在线
  • ChatGPT 再遭禁用

    近日 三星电子宣布禁止员工使用流行的生成式AI工具 原因在于4月初三星内部发生的三起涉及 ChatGPT 误用造成的数据泄露事件 报道称 三星半导体设备测量资料 产品良率等内容或已被存入ChatGPT学习资料库中 去年11月上线以来 Cha
  • 超高清

    海思 HDR HDR行业面临巨大挑战 01 标准不统一 终端呈现效果参差不齐 HDR多种技术标准共存 缺少终端侧技术实现方案 标准间兼容性较差 不能覆盖主流终端的适配 认证及测试过程 导致终端呈现效果差距大 02 生态碎片化 部分技术方案专
  • Android系统开发之修改Captive Potal Service(消灭感叹号)

    本文原作者 长鸣鸟 未经同意 转载不带名的严重鄙视 谷歌在Android5 0之后的版本加入了CaptivePotalLogin服务 本服务的功能是检查网络连接互联网情况 主要针对于Wi Fi 不让Android设备自动连接那些不能联网的无
  • Visio 2007/2010 左侧"形状"窗口管理

    Visio 2007 2010 左侧 形状 窗口管理 Visio 打开后 通常窗口左侧会有一个 形状 面板 我们可以方便地从中选择需要的形状 有时为了获得更大的版面空间或者不小心关闭了形状面板 怎么把它重新调出来 我们可以从 视图 中把它找
  • 代码随想录算法训练营第三天

    今天是算法训练营的第三天 写了454 四数相加 II这道题目 力扣链接 代码随想录链接 代码如下 class Solution def fourSumCount self nums1 List int nums2 List int nums
  • 独家

    随机森林 概述 当变量的数量非常庞大时 你将采取什么方法来处理数据 通常情况下 当问题非常庞杂时 我们需要一群专家而不是一个专家来解决问题 例如Linux 它是一个非常复杂的系统 因此需要成百上千的专家来搭建 以此类推 我们能否将许多专家的
  • 【华为OD统一考试B卷

    在线OJ 已购买本专栏用户 请私信博主开通账号 在线刷题 运行出现 Runtime Error 0Aborted 请忽略 华为OD统一考试A卷 B卷 新题库说明 2023年5月份 华为官方已经将的 2022 0223Q 1 2 3 4 统一
  • 动态规划(五)

    01背包问题 01 Knapsack problem 有10件货物要从甲地运送到乙地 每件货物的重量 单位 吨 和利润 单位 元 如下表所示 由于只有一辆最大载重为30t的货车能用来运送货物 所以只能选择部分货物配送 要求确定运送哪些货物
  • Matplotlib

    1 折线图 import matplotlib pyplot as plt import numpy as np x np linspace 1 1 50 1到1 有五十个点 y 2 x 1 plt figure num 1 figsize
  • 第1课:三位一体定位法,让写作事半功倍

    做最懂技术的传播者 最懂传播的工程师 课程内容分析 本课程的目标是 通过对一系列问题的梳理 找到适合自己的输出状态 确定与理想输出状态之间存在的差距 以及采取什么办法 减少差距 知识要点 1 受众需要什么 省时间的内容 收敛 看过就走 教你
  • java错误-The prefix "aop" for element "aop:aspectj-autoproxy" is not bound.

    配置springmvc的aop时出错 当我向配置文件中添加
  • 年底裁员潮,你有没有被"N+1"?

    2018年11月28日上午 前一天加班到深夜的李女士 又一大早起床匆匆赶去上班了 她在一家垂直电商公司工作多年 岁末将至 一切和往常一样 为了在年前完成比上一季度更高的 KPI 她所在团队经常通宵达旦赶工 李女士准备开始新一天的鸡血工作 主
  • 数学甜点004

    数学是一门及其高深又变幻莫测的学科 且其根本就是问题的解决 因此是不可能也没有必要去寻找一种能够解决所有问题的通解的 坦白说 研究数学的最大乐趣就是在于发现从来没有人走过的新道路 即一种不同于常规的具有跳跃性 构造性的解法 换句话说 无论是
  • 时序预测

    时序预测 MATLAB实现AR时间序列预测 目录 时序预测 MATLAB实现AR时间序列预测 基本介绍 程序设计 学习总结 参考资料 基本介绍 如果某个时间序列的任意数值可以表示自回归方程 那么该时间序列服从p阶的自回归过程 可以表示为AR
  • 你需要知道面试中的10个JavaScript概念

    翻译原文出处 10 JavaScript concepts you need to know for interviews 之前不是闹得沸沸扬扬的大漠穷秋文章 为什么只会Vue的都是前端小白 甚至大多数回头看了 也就会jQuery和Vue这
  • AI绘画

    今天用Midjourney生成了质量极高的美少女武士后续会作为固定栏目来分享美图接下来请欣赏作品 提示词分享 1 an asian girl dressed in samurai style in the style of anime ae
  • 多维时序

    多维时序 MATLAB实现Attention LSTM 注意力机制长短期记忆神经网络 多输入单输出 目录 多维时序 MATLAB实现Attention LSTM 注意力机制长短期记忆神经网络 多输入单输出 基本介绍 模型背景 LSTM模型