基于自注意力机制的LSTM多变量负荷预测

2023-11-11

1.引言 

       在之前使用长短期记忆网络构建电力负荷预测模型的基础上,将自注意力机制 (Self-Attention)融入到负荷预测模型中。具体内容是是在LSTM层后面接Self-Attention层,在加入Self-Attention后,可以将负荷数据通过加权求和的方式进行处理,对负荷特征添加注意力权重,来突出负荷的影响因数。结果表明,通过自注意力机制,可以更好的挖掘电力负荷数据的特征以及变化规律信息,提高预测模型的性能。

        环境:python3.8,tensorflow2.5.

2.原理

2.1.自注意力机制

        自注意力机制网上很多推导,这里就不再赘述,需要的可以看博客,这个博客讲的很好。

2.2 模型结构

主要包含输入层,LSTM层,位置编码层,自注意力机制层,以及输出层。

3. 实战

3.1 数据结构

        采用2016电工杯负荷预测数据,每15分钟采样一次,一天共96个负荷值与5个气象数据(温度湿度降雨量啥的)。我们采用滚动建模预测,就是利用1到n天的所有值为输入,第n+1天的96个负荷值为输出;然后2到n+1天的所有值为输入,第n+2天的96个负荷值为输出,这样进行滚动序列建模。这个n就是时间步,程序里面设置的是20,所以上面的输入层你看到是Nonex20x101,输出是Nonex96。

3.2 建模预测

# coding: utf-8
from sklearn.preprocessing import StandardScaler,MinMaxScaler
from sklearn.metrics import r2_score
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input,Dense,LSTM
import tensorflow as tf
from Layers import SelfAttention,AddSinusoidalPositionalEncodings
os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin/'
from tensorflow.keras.utils import plot_model

# In[]定义一些需要的函数

def build_model(seq,fea,out):
    input_ = Input(shape=(seq,fea))
    x=LSTM(20, return_sequences=True)(input_)    
    pos = AddSinusoidalPositionalEncodings()(x)
    att = SelfAttention(100,100,return_sequence=False, dropout=.0)(pos)
    
    out = Dense(out, activation=None)(att)
    model = Model(inputs=input_, outputs=out)
    return model


def split_data(data, n):
    in_ = []
    out_ = []    
    N = data.shape[0] - n
    for i in range(N):
        in_.append(data[i:i + n,:])
        out_.append(data[i + n,:96])
    in_ = np.array(in_).reshape(len(in_), -1)
    out_ = np.array(out_).reshape(len(out_), -1)
    return in_, out_
def result(real,pred,name):
    # ss_X = MinMaxScaler(feature_range=(-1, 1))
    # real = ss_X.fit_transform(real).reshape(-1,)
    # pred = ss_X.transform(pred).reshape(-1,)
    real=real.reshape(-1,)
    pred=pred.reshape(-1,)
    # mape
    test_mape = np.mean(np.abs((pred - real) / real))
    # rmse
    test_rmse = np.sqrt(np.mean(np.square(pred - real)))
    # mae
    test_mae = np.mean(np.abs(pred - real))
    # R2
    test_r2 = r2_score(real, pred)

    print(name,'的mape:%.4f,rmse:%.4f,mae:%.4f,R2:%.4f'%(test_mape ,test_rmse, test_mae, test_r2))

# In[]
df=pd.read_csv('数据集/data196.csv').fillna(0).iloc[:,1:]
data=df.values
time_steps=20
in_,out_=split_data(data,time_steps)

n=range(in_.shape[0])
#m=int(0.8*in_.shape[0])#前80%训练 后20%测试
m=-2#最后两天测试
train_data = in_[n[0:m],]
test_data = in_[n[m:],]
train_label = out_[n[0:m],]
test_label = out_[n[m:],]

# 归一化
ss_X = StandardScaler().fit(train_data)
ss_Y = StandardScaler().fit(train_label)
# ss_X = MinMaxScaler(feature_range=(0, 1)).fit(train_data)
# ss_Y = MinMaxScaler(feature_range=(0, 1)).fit(train_label)
train_data = ss_X.transform(train_data).reshape(train_data.shape[0], time_steps, -1)
train_label = ss_Y.transform(train_label)

test_data = ss_X.transform(test_data).reshape(test_data.shape[0], time_steps, -1)
test_label = ss_Y.transform(test_label)
# In[]
model=build_model(train_data.shape[-2],train_data.shape[-1],train_label.shape[-1])
#查看网络结构
model.summary()
plot_model(model, show_shapes=True, to_file='result/lstmsa_model.jpg')

train_again=True  #为 False 的时候就直接加载训练好的模型进行测试
#训练模型
if train_again:
    #编译模型
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss='mse')
    #训练模型
    history=model.fit(train_data,train_label,batch_size=64,epochs=100,
                      verbose=1,validation_data=(test_data,test_label))
    # In[8]
    model.save_weights('result/lstmsa_model.h5')
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    plt.plot( loss, label='Train Loss')
    plt.plot( val_loss, label='Test Loss')
    plt.title('Train and Val Loss')
    plt.legend()
    plt.savefig('result/lstmsa_model_loss.jpg')
    plt.show()
else:#加载模型
    model.load_weights('result/lstmsa_model.h5')

# In[]
test_pred = model.predict(test_data)

# 对测试集的预测结果进行反归一化
test_label1 = ss_Y.inverse_transform(test_label)
test_pred1 = ss_Y.inverse_transform(test_pred)
# In[]计算各种指标
result(test_label1,test_pred1,'LSTM-SA')
np.savez('result/lstmsa1.npz',real=test_label1,pred=test_pred1)

test_label=test_label1.reshape(-1,)
test_pred=test_pred1.reshape(-1,)
# plot test_set result
plt.figure()
plt.plot(test_label, c='r', label='real')
plt.plot(test_pred, c='b', label='pred')
plt.legend()
plt.xlabel('样本点')
plt.ylabel('功率')
plt.title('测试集')
plt.show()

 3.2 结果对比

        将其与RNN、LSTM进行对比,结果如下

测试集取的是最后两天的,从结果上看,显然提出的方法效果最好 

4.代码

        详细代码见评论区。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

基于自注意力机制的LSTM多变量负荷预测 的相关文章

  • 如何快速申请GPT账号?

    详情点击链接 如何快速申请GPT账号 一OpenAI 1 最新大模型GPT 4 Turbo 2 最新发布的高级数据分析 AI画图 图像识别 文档API 3 GPT Store 4 从0到1创建自己的GPT应用 5 模型Gemini以及大模型
  • 手把手教你用 Stable Diffusion 写好提示词

    Stable Diffusion 技术把 AI 图像生成提高到了一个全新高度 文生图 Text to image 生成质量很大程度上取决于你的提示词 Prompt 好不好 前面文章写了一篇文章 一份保姆级的 Stable Diffusion
  • 主流进销存系统有哪些?企业该如何选择进销存系统?

    主流进销存系统有哪些 企业该如何选择进销存系统 永久免费 的软件 这个可能还真不太可能有 而且就算有 也只能说是相对免费 因为要么就是数据存量有限 要么就是功能有限 数据 信息都不保障 并且功能不完全 免费 免费软件 免费进销存 诸如此类
  • 蒙特卡洛在发电系统中的应用(Matlab代码实现)

    欢迎来到本博客 博主优势 博客内容尽量做到思维缜密 逻辑清晰 为了方便读者 座右铭 行百里者 半于九十 本文目录如下 目录 1 概述 2 运行结果 3 参考文献 4 Matlab代码实现
  • 史上最全自动驾驶岗位介绍

    作者 自动驾驶转型者 编辑 汽车人 原文链接 https zhuanlan zhihu com p 353480028 点击下方 卡片 关注 自动驾驶之心 公众号 ADAS巨卷干货 即可获取 点击进入 自动驾驶之心 求职交流 技术交流群 本
  • 15天学会Python深度学习,我是如何办到的?

    陆陆续续有同学向我们咨询 Python编程如何上手 深度学习怎么学习 如果有人能手把手 一对一帮帮我就好了 我们非常理解初学者的茫然和困惑 大量视频 书籍 广告干扰了大家的判断 学习Python和人工智能 成为内行人不难 为此 我们推出了
  • 考虑光伏出力利用率的电动汽车充电站能量调度策略研究(Matlab代码实现)

    欢迎来到本博客 博主优势 博客内容尽量做到思维缜密 逻辑清晰 为了方便读者 座右铭 行百里者 半于九十 本文目录如下 目录 1 概述 2 运行结果 3 参考文献 4 Matlab代码 数据
  • 自动驾驶离不开的仿真!Carla-Autoware联合仿真全栈教程

    随着自动驾驶技术的不断发展 研发技术人员开始面对一系列复杂挑战 特别是在确保系统安全性 处理复杂交通场景以及优化算法性能等方面 这些挑战中 尤其突出的是所谓的 长尾问题 即那些在实际道路测试中难以遇到的罕见或异常驾驶情况 这些问题暴露了实车
  • Making Large Language Models Perform Better in Knowledge Graph Completion论文阅读

    文章目录 摘要 1 问题的提出 引出当前研究的不足与问题 KGC方法 LLM幻觉现象 解决方案 2 数据集和模型构建
  • 对可变长度序列进行训练和预测

    传感器 同类型的 分散在我的网站上 不定期地手动向我的后端报告 在报告之间 传感器聚合事件并批量报告它们 以下数据集是批量收集的序列事件数据的集合 例如传感器 1 报告了 2 次 在第一批 2 个事件和第二批 3 个事件中 传感器 2 报告
  • 了解 Tensorflow LSTM 模型输入?

    我在理解 TensorFlow 中的 LSTM 模型时遇到一些困难 我用tflearn http tflearn org 作为包装器 因为它自动完成所有初始化和其他更高级别的工作 为了简单起见 我们考虑这个示例程序 https github
  • seq2seq 中的 TimeDistributed(Dense) 与 Dense

    鉴于下面的代码 encoder inputs Input shape 16 70 encoder LSTM latent dim return state True encoder outputs state h state c encod
  • 结合 CNN 和双向 LSTM

    我正在尝试结合 CNN 和 LSTM 进行图像分类 我尝试了以下代码 但收到错误 我有 4 个课程需要训练和测试 以下是代码 from keras models import Sequential from keras layers imp
  • Python - 基于 LSTM 的 RNN 需要 3D 输入?

    我正在尝试构建一个基于 LSTM RNN 的深度学习网络 这是尝试过的 from keras models import Sequential from keras layers import Dense Dropout Activatio
  • 将 CNN 的输出传递给 BILSTM

    我正在开发一个项目 其中我必须将 CNN 的输出传递给双向 LSTM 我创建了如下模型 但它抛出 不兼容 错误 请让我知道哪里出了问题以及如何解决这个问题 model Sequential model add Conv2D filters
  • 如何使用有状态 LSTM 和 batch_size > 1 布置训练数据

    背景 我想在 Keras 中对 有状态 LSTM 进行小批量训练 我的输入训练数据位于一个大矩阵 X 中 其维度为 m x n 其中 m number of subsequences n number of time steps per s
  • 尝试理解 Pytorch 的 LSTM 实现

    我有一个包含 1000 个示例的数据集 其中每个示例都有5特征 a b c d e 我想喂7LSTM 的示例 以便它预测第 8 天的特征 a 阅读 nn LSTM 的 Pytorchs 文档 我得出以下结论 input size 5 hid
  • Keras:嵌入 LSTM

    在 LSTM 的 keras 示例中 用于对 IMDB 序列数据进行建模 https github com fchollet keras blob master examples imdb lstm py https github com
  • 张量流:简单 LSTM 网络的共享变量错误

    我正在尝试构建一个最简单的 LSTM 网络 只是想让它预测序列中的下一个值np input data import tensorflow as tf from tensorflow python ops import rnn cell im
  • Caffe 的 LSTM 模块

    有谁知道 Caffe 是否有一个不错的 LSTM 模块 我从 russel91 的 github 帐户中找到了一个 但显然包含示例和解释的网页消失了 以前是http apollo deepmatter io http apollo deep

随机推荐

  • android nfc中Ndef格式的读写

    原文地址 检测到标签后在Activity中的处理流程 1 在onCreate 中获取NfcAdapter对象 NfcAdapter nfcAdapter NfcAdapter getDefaultAdapter this 2 在onNewI
  • 微信小程序 实现天气类功能

    参考链接 1 全国城市天气预报 城市天气预报查询 国内天气预报查询 天气网 https www tianqi com chinacity html 2 获取实时天气数据 获取数据 开发指南 微信小程序SDK 高德地图API https lb
  • apache字体文件跨域_css引用跨域字体文件woff,eot,ttf问题

    今天把站点的字体的静态文件woff eot ttf放到cdn去速度快一些 改成了外链地址 居然不加载报错 用下面的公用地址可以正常使用 https cdn bootcss com font awesome 4 7 0 fonts 搜索下 是
  • H5 页面采坑记录

    1 页面布局时 上下滑动页面时通常会把一些盒子放在 section section 标签中 但是在一些机型如iphonex测试中 上下滑动页面会出现都抖动的情况 不知道什么原因 解决方案就是 不使用 section 标签 直接在大盒子中写滚
  • 多线程之常用线程安全类型分析

    写在前面 本文一起看下在日常工作中我们经常用到的线程安全的数据类型 以及一些经验总结 1 常用线程安全数据类型 1 1 jdk集合数据类型 jdk的集合数据类型分为两类 一种是线性数据结构 另外一种是字典结构 分别看下 1 1 1 线性数据
  • 通过PyInstaller打包报“文件遇到错误”

    前言 不知道大家在作为python程序后 是不是都通过PyInstaller打包给用户使用呢 但是通过PyInstaller打包会出现一点小小的问题 本文章就来教你如何去解决这些问题 让打包后显示出控制台窗口 在打包的时候 不用加上 w让窗
  • 解码(二):音视频解码上下文创建配置和打开avcodec_open2打开演示

    如下代码 视频解码器打开 找到视频解码器 AVCodec vcodec avcodec find decoder ic gt streams videoStream gt codecpar gt codec id if vcodec cou
  • 远期与期货

    概述 期货合约与远期合约都是规定在将来的某一时间购买或者出售某项资产 这一点与期权类似 关键不同之处在于 期权持有者不会被强制购买或者出售资产 当无利可图时 可以选择放弃交易 但是 期货或者远期合约由必须履行事先约定的合约义务 远期 仅仅是
  • Java Lombok 报错(IllegalAccessError: class lombok.javac.apt.LombokProcessor)解决方法

    本文主要介绍Java 中 使用Lombok报错 java java lang IllegalAccessError class lombok javac apt LombokProcessor的解决方法及示例代码 原文地址 Java Lom
  • Java Swing 如何让界面更加美观

    文章目录 一 设置窗体的背景图 二 设置Button组件 三 设置字体大小和颜色 四 设置组件的背景色 五 综合测试案例 一 设置窗体的背景图 利用JLable类的构造方法或方法加载图片 ImageIcon image new ImageI
  • 设计一个雇员Employee类

    题目内容 设计一个雇员Employee类 具体要求如下 1 设计雇员Employee类 记录雇员的情况 包括姓名 年薪 受雇时间 String name double salary MyDate start 2 定义MyDate类作为日期
  • 装系统时提示 无法在驱动器0分区上安装windows

    先看提示 先看提示 先看提示 1 在重装系统时遇到一个问题 无法在驱动器0分区上安装windows 2 解决方法 1 在当前安装界面按住Shift F10调出命令提示符窗口 2 输入diskpart 按回车执行 3 进入DISKPART命令
  • 负数为什么要用补码来表示?

    上篇文章讲了 负数在计算机中是怎么存储的 看完之后 应该对原码 反码 补码有了基本的了解了 今天 我们深入探讨一下 为什么计算机中要用补码来表示负数 首先 我们应该清楚 原码是方便给人看的 看到一个数的原码 我们就能根据符号位和后边的二进制
  • [144]如何用VBS编写一个简单的恶搞脚本

    windows系统的电脑 首先右击桌面 选择新建 文本文档 在桌面上新建一个文本文档 随后打开计算机或者是我的电脑 点击其中的组织 xp系统多为工具 选择下面的文件夹和搜索选项 在弹出的窗口中点击查看 向下滚到 找到隐藏已知文件类型的扩展名
  • Android(Kotlin)获取应用全局上下文 ApplicationContext

    需求 Android Kotlin 获取应用全局上下文 ApplicationContext 有些场景下需要使用的 Context 是和页面无关的 仅和应用进程相关 比如 读写文件或访问数据库 这些场景下 我们希望可以在项目内任意位置 直接
  • Allegro PCB的布局

    1 手工导入元器件 place manually进入放置设置页面 在需要放置的元器件前面打勾 可以依次放置元器件 2 快速放置元器件 place Quickplace 使用快速放置功能需要先画好板宽outline才可以 3 设置room区域
  • c++实现数据结构栈和队列

    1 栈 头文件 ifndef ZHAN H define ZHAN H define MAX 8 include
  • laravel-admin安装及使用教程

    安装命令 安装 Laravel 安装器 composer global require laravel installer 创建名为 shopAdmin 项目 laravel new shopAdmin 经过漫长的等待已经安装好了 进入项目
  • springboot中注入FilterRegistrationBean不生效原因

    springboot中注入FilterRegistrationBean不生效原因 回顾 最近自定义了两个过滤器 接口请求返回加密和sql注入处理过滤器 因为在封装一些工具包 我在单独调好之后 就打算做成一个注解 像springboot启动类
  • 基于自注意力机制的LSTM多变量负荷预测

    1 引言 在之前使用长短期记忆网络构建电力负荷预测模型的基础上 将自注意力机制 Self Attention 融入到负荷预测模型中 具体内容是是在LSTM层后面接Self Attention层 在加入Self Attention后 可以将负