GRU解决预测分类问题(多变量预测多步)

2023-11-15

解决问题的背景:现有五个属性列,前四个属性列作为特征输入,第五个属性列作为标签值,第五个属性列的意义是类别;先需要通过前50步的数据特征预测后10步的类别(即:51-60步)。

1.直接多输出的方式:直接多输出的方式就是在神经网络的最后加上几个(对应的是需要预测步长是几步,这里是10)一样的全连接神经网络,在这一层之后进行对每个全连接神经网络输出的值的拼接得到一个10步长的结果,用于后面计算损失进行训练。

简单的网络结构如下图:

模型网络的代码如下:

# GRU
class GRURNN(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = torch.nn.GRU(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
        self.fc1 = torch.nn.Linear(self.hidden_size, 4)
        self.fc2 = torch.nn.Linear(self.hidden_size, 4)
        self.fc3 = torch.nn.Linear(self.hidden_size, 4)
        self.fc4 = torch.nn.Linear(self.hidden_size, 4)
        self.fc5 = torch.nn.Linear(self.hidden_size, 4)
        self.fc6 = torch.nn.Linear(self.hidden_size, 4)
        self.fc7 = torch.nn.Linear(self.hidden_size, 4)
        self.fc8 = torch.nn.Linear(self.hidden_size, 4)
        self.fc9 = torch.nn.Linear(self.hidden_size, 4)
        self.fc10 = torch.nn.Linear(self.hidden_size, 4)
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, input_seq):
        batch_size = input_seq.shape[0]
        h_0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)
        output, _ = self.gru(input_seq,h_0)
        pred1 = self.fc1(output)
        pred2 = self.fc2(output)
        pred3 = self.fc3(output)
        pred4 = self.fc4(output)
        pred5 = self.fc5(output)
        pred6 = self.fc6(output)
        pred7 = self.fc7(output)
        pred8 = self.fc8(output)
        pred9 = self.fc9(output)
        pred10 = self.fc10(output)
        pred1, pred2, pred3, pred4, pred5, pred6, pred7, pred8, pred9, pred10 = pred1[:, -1, :], pred2[:, -1, :], pred3[:, -1, :], pred4[:, -1, :], pred5[:, -1, :], pred6[:, -1, :], pred7[:, -1, :], pred8[:, -1, :], pred9[:, -1, :], pred10[:, -1, :]
        pred1, pred2, pred3, pred4, pred5, pred6, pred7, pred8, pred9, pred10 = self.softmax(pred1), self.softmax(pred2), self.softmax(pred3), self.softmax(pred4), self.softmax(pred5), self.softmax(pred6), self.softmax(pred7), self.softmax(pred8), self.softmax(pred9), self.softmax(pred10)
        pred = torch.stack([pred1, pred2, pred3, pred4, pred5, pred6, pred7, pred8, pred9, pred10], dim=1)
        return pred

2.滚动数据集输出的方式:滚动数据集的方式就是单步预测的一个整合的版本,具体就是先用前50步预测第51步然后用2-51步作为50步的值进行下一次的输入预测第52步,以此类推;这里后面预测完加入到输入数据中的新值可以就是刚刚预测出来的新值,也可以是数据标签集值的对应到这一步的值。滚动预测的效果会比直接多输出的方式的效果好,但是时间是较长的,对于需要一个较好性能模型的需求来说,时间久一点不是什么问题。

简单的预测步骤如下图(简单表示:5步预测3步):

 模型网络的代码如下:

# GRU
class GRURNN(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = torch.nn.GRU(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, 32),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(32, 16),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(16, 4)
        )
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, input_seq):
        batch_size = input_seq.shape[0]
        h_0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)
        output, _ = self.gru(input_seq,h_0)
        output = output[:, -1, :] 
        pred = self.mlp(output)
        pred = self.softmax(pred)
        return pred


# 直接单步滚动,预测未来多步的预测
class GRURNN_PRO_MORE(torch.nn.Module):
    def __init__(self,gru,device):
        super(GRURNN_PRO_MORE, self).__init__()
        self.gru = gru
        self.device = device
    def forward(self, src, trg):
        batch_size = src.shape[0]
        src_len = src.shape[1]
        trg_len = trg.shape[1]
        output_size = 4

        outputs = torch.zeros(batch_size, trg_len, output_size).to(self.device)

        for i in range(trg_len):
            src = src.float()
            output = self.gru(src)
            outputs[:, i, :] = output
            trg_input = trg[:, i, :4].reshape([batch_size, 1, output_size])
            src = torch.cat((src[:, 1:, :], trg_input),dim=1)
        return outputs

数据处理:对于数据的处理用到了常用的一些库,像pandas,numpy等。

作者处于学习阶段,如有错误,欢迎批评指正。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

GRU解决预测分类问题(多变量预测多步) 的相关文章

随机推荐

  • java中的异常 最详细的讲解

    一 异常的概念 异常 在程序执行的过程中 出现的非正常情况 最终会导致JVM非正常停止 在Java等面向对象的编程语言中 异常本身是一个类 产生异常就是创建异常对象并抛出了一个异常对象 Java处 理异常的方式是中断处理 异常指的并不是语法
  • Nginx基本入门

    本文转载至 http blog csdn net u012486840 article details 53098890 1 静态HTTP服务器 首先 Nginx是一个HTTP服务器 可以将服务器上的静态文件 如HTML 图片 通过HTTP
  • [leetcode] 适合打劫银行的日子 -前缀和

    题目链接 前缀和思想 用数组 l l l 表示前面有多少个数满足 a i
  • MySQL----内置函数

    MySQL gt 内置函数 函数 将经常使用的代码封装起来 需要的时候直接调用就可以 从函数定义角度 函数可分为 内置函数 系统内置的通用函数 自定义函数 需要根据需求编写的函数 MySQL提供的内置函数从实现的功能角度可以分为数值函数 字
  • 多益视频面试

    多益面试 有一种怀疑人生的感觉 向老师 我对不起你 去年刚学的网络安全 我竟然没说出来加密算法的名字 也并不是题很难 而是简单的就是说不出来 写不出来 而难的也就是听过而已 问题 1 说一下什么是线程安全 线程安全的场景 线程安全就是确保程
  • 单相机做分屏混合

    做了一个单相机实现分屏混合的功能 需求大概就是在同一视角下 相机通过不同的CullingMask获取不同的渲染图片RenderTexture之后 通过某种方式一起显示在界面的功能 其实核心逻辑就是怎样用一个相机渲染不同的图片罢了 直接上代码
  • 在Java中,将ExecutorService转为守护程序

    问题描述 我正在Java 1 6中使用一个ExecutoreService 简单地开始 ExecutorService pool Executors newFixedThreadPool THREADS 当我的主线程完成 以及由线程池处理的
  • vue调用原生方法交互

    目前在做一个H5应用 需要调用原生方法进行交互 特此做一个记录 技术栈 vue版本2 6 vant版本 2 12 第一步 声明一个rpcFn js文件 进行原生交互阿里桥封装 const rpc function url params re
  • github actions实现Android持续集成

    持续集成 Continuous Integration 在很多单位都有现成的系统 但是作为一名工程师 我们还是要了解其原理 可以自己尝试做一下 经过本人的尝试 发现功能并不复杂 这里把持续集成实践经验总结与大家分享 持续集成用的比较多的是j
  • 2023自动化专业毕业设计项目集合

    文章目录 1前言 2 如何选题 2 1 物联网方向 2 2 嵌入式开发方向 2 3 人工智能方向 2 4 算法研究方向 2 5 学长作品展示 4 最后 1前言 近期不少学弟学妹询问学长关于自动化专业相关的毕设选题 学长特意写下这篇文章以作回
  • 基于springcloud gateway + nacos实现灰度发布(reactive版)

    什么是灰度发布 灰度发布 又名金丝雀发布 是指在黑与白之间 能够平滑过渡的一种发布方式 在其上可以进行A B testing 即让一部分用户继续用产品特性A 一部分用户开始用产品特性B 如果用户对B没有什么反对意见 那么逐步扩大范围 把所有
  • 一个网站引发的程序猿的牢骚,哈哈哈

    2013年大学毕业后 参加工作做的第一个前端项目 北京服装学院 今天调研一个关于iframe的需求 突然想试试 以前那些做IE6兼容的项目是否还在使用 就默默的点开了 十年了 他们没有换网站 我的岁月似乎从这一刻又回来了一次 已经十年了 我
  • Flask学习笔记(二)

    Flask学习笔记 二 1 知识点 1 1虚拟环境 1 1 1virtualenv 1 1 2virtualenvwrapper 1 2web与视图 1 3jinja2 1 3 1template知识点 1 3 2豆瓣列表页 1 3 3视图
  • 锚框损失论文下载 Iou-Loss【IoU Loss、GIoU Loss、 DIoU Loss 、CIoU Loss、 CDIoU Loss、 F-EIoU Loss、α-IoU Loss】

    锚框损失 Iou Loss IoU Loss GIoU Loss DIoU Loss CIoU Loss CDIoU Loss F EIoU Loss IoU Loss 论文打包下载 yolo系列论文https download csdn
  • cocosCreator2.3.x渲染流程深入剖析笔记(三)

    渲染批次合并之顶点 根据前面说过的render flow流程接下来就是重头戏了render流程 其中包括了 检查两个渲染节点是否可以合并 同时把renderData的数据填充到modelBatch里的buffer中去 所有需要渲染的节点都有
  • Kotlin中匿名函数(又称为Lambda,或者闭包)和高阶函数的详解

    博主前些天发现了一个巨牛的人工智能学习网站 通俗易懂 风趣幽默 忍不住也分享一下给大家 点击跳转到教程 1 匿名函数 fun main 匿名函数 1 定义时不取名字的函数 我们称之为匿名函数 匿名函数通常整体传递给其他函数 或者从其他函数返
  • java中到底该不该用@author标识作者?

    今天查看activiti的README 突然发现一段很有意思的FAQ Why do you not accept author lines in your source code Because the author tags in the
  • Redis基础

    一 Redis入门 1 Redis简介 Redis Remote Dictionary Server 即远程字典服务 是一个基于内存的key value结构数据库 是用C语言开发的一个开源的高性能键值对 key value 数据库 它可以用
  • 基于python 蔬菜价格数据分析 完整代码+数据

    https download csdn net download weixin 55771290 87567123
  • GRU解决预测分类问题(多变量预测多步)

    解决问题的背景 现有五个属性列 前四个属性列作为特征输入 第五个属性列作为标签值 第五个属性列的意义是类别 先需要通过前50步的数据特征预测后10步的类别 即 51 60步 1 直接多输出的方式 直接多输出的方式就是在神经网络的最后加上几个