基于TCN时间卷积网络(含因果膨胀卷积)的单特征输入股票预测项目实战(pytorch)(一维特征)【有数据集和代码,可运行】

2023-11-08

一、项目简介

股票预测是金融领域中的重要问题,通过对历史股票数据的分析和建模,我们可以尝试预测未来股票的价格趋势,为投资决策提供参考。本项目是基于PyTorch深度学习框架实现一个使用时间卷积网络(TCN,Temporal Convolutional Network)来进行股票预测的项目,该网络通过堆叠因果卷积和扩张卷积,能够捕捉时间序列依赖关系和特征,可以有效地处理时间序列数据,最后进行预测。网络的输入是历史股票收盘价(单特征),输出是预测的股票收盘价。【本项目的代码文件分模块整理,包含模型构建、数据划分、训练过程等模块都清晰分明】

二、TCN时间卷积网络简介

该算法于2016年由Lea等人他们在做视频动作分割的研究首先提出,CNN模型以 CNN 模型为基础,并做了如下改进:①适用序列模型:因果卷积(Causal Convolution)②记忆历史:空洞卷积/膨胀卷积(Dilated Convolution),残差模块(Residual block)。

1、因果卷积(Causal Convolution)

因果卷积是一种卷积操作,主要应用于时间序列数据或具有时序性的数据分析任务。因果卷积在卷积操作中引入了因果性,确保输出只依赖于过去的输入数据,不受未来信息的影响。(下图为因果卷积过程)

2、膨胀因果卷积(Dilated Causal Convolutions)

膨胀因果卷积(Dilated Causal Convolution)是在因果卷积的基础上引入了膨胀(Dilation)操作的一种卷积操作。它结合了因果卷积和扩张卷积(Dilated Convolution)的特性,能够增加感受野(Receptive Field)并保持因果性。(下图为膨胀因果卷积,相比于上图,以较少的计算代价实现了更大的感受野)

 3、TCN整体框架

整体的TCN模型如下图所示,较深的网络结构可能会引起梯度消失等问题,该模型利用了一种类似于ResNet中的残差块的结构,这样设计的TCN结构更加的具有泛化能力。

三、实验数据集

实验采用的是深沪300数据集sh300.csv(后文有源码和数据集获取方式),这是公开的数据集,百度一下应该也可以找得到,数据集展示如下(所展示的数据是sh300_test.csv,本人在原数据集基础上删掉某些列的数据),实验只使用了红框中的收盘价作为输入特征进行预测。

数据划分:以滑窗的方式进行数据划分,滑窗大小为20,输入特征为1,每次滑窗的第21天为预测的标签值。

四、实验环境

平台:Window 11

语言:python3.9

编译器:Pycharm

框架:Pytorch:1.13.1

五、实验内容及部分代码展示

1、model_TCN.py 模型构建

model_TCN.py定义了项目用到的网络TCN模型,该模型由多个组件组成,包括Crop模块、TemporalCasualLayer模块、TemporalConvolutionNetwork模块和TCN模块。Crop模块用于裁剪输入张量的时间维度,去除多余的padding部分。TemporalCasualLayer模块实现了一个膨胀卷积层,由两个膨胀卷积块组成。每个膨胀卷积块包含一个带有权重归一化的卷积层、裁剪模块、ReLU激活函数和Dropout正则化。此外,还包括一个用于快捷连接的卷积层。TemporalConvolutionNetwork模块通过堆叠多个TemporalCasualLayer组成了一个完整的TCN网络。每个TemporalCasualLayer具有不同的膨胀系数,并根据输入和输出通道的数量进行设置。TCN模块封装了TemporalConvolutionNetwork,并添加了一个线性层用于最终的预测。在前向传播中,先经过TCN网络,然后将输出的最后一个时间步传入线性层,并通过ReLU激活函数进行非线性变换。通过以上组件的组合,该TCN模型可以用于时间序列建模任务,并在最后输出预测结果。(本来网络最后一层是没有激活Relu的,后来发现效果不是很好,我在最后一层加上激活Relu后,实现预测的效果还不错)

import torch.nn as nn
from torch.nn.utils import weight_norm

#用于裁剪输入张量的时间维度,去除多余的 padding 部分。
class Crop(nn.Module):

    def __init__(self, crop_size):
        super(Crop, self).__init__()
        self.crop_size = crop_size

    def forward(self, x):
        #裁剪张量以去除额外的填充
        return x[:, :, :-self.crop_size].contiguous()

#实现了一个膨胀卷积层,由两个膨胀卷积块组成。每个膨胀卷积块包含一个带有权重归一化的卷积层、裁剪模块、ReLU激活函数和 Dropout 正则化。此外,还包括了一个用于快捷连接的卷积层
class TemporalCasualLayer(nn.Module):

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, dropout=0.2):
        super(TemporalCasualLayer, self).__init__()
        padding = (kernel_size - 1) * dilation
        conv_params = {
            'kernel_size': kernel_size,
            'stride': stride,
            'padding': padding,
            'dilation': dilation
        }

        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, **conv_params))
        self.crop1 = Crop(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, **conv_params))
        self.crop2 = Crop(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.crop1, self.relu1, self.dropout1,
                                 self.conv2, self.crop2, self.relu2, self.dropout2)
        #快捷连接
        self.bias = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()

    def forward(self, x):
        # 应用因果卷积和快捷连接
        y = self.net(x)
        b = x if self.bias is None else self.bias(x)
        return self.relu(y + b)

#通过堆叠多个 TemporalCasualLayer 组成了一个完整的 TCN 网络。每个 TemporalCasualLayer 具有不同的膨胀系数,并根据输入和输出通道的数量进行设置。
class TemporalConvolutionNetwork(nn.Module):

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvolutionNetwork, self).__init__()
        layers = []
        num_levels = len(num_channels)
        tcl_param = {
            'kernel_size': kernel_size,
            'stride': 1,
            'dropout': dropout
        }
        for i in range(num_levels):
            dilation = 2 ** i
            in_ch = num_inputs if i == 0 else num_channels[i - 1]
            out_ch = num_channels[i]
            tcl_param['dilation'] = dilation
            tcl = TemporalCasualLayer(in_ch, out_ch, **tcl_param)
            # tcl = self.relu(tcl)
            layers.append(tcl)

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

#封装了 TemporalConvolutionNetwork,并添加了一个线性层用于最终的预测。在前向传播中,先经过 TCN 网络,然后将输出的最后一个时间步传入线性层,并通过 ReLU 激活函数进行非线性变换。
class TCN(nn.Module):

    def __init__(self, input_size, output_size, num_channels, kernel_size, dropout):
        super(TCN, self).__init__()
        self.tcn = TemporalConvolutionNetwork(input_size, num_channels, kernel_size=kernel_size, dropout=dropout)
        self.linear = nn.Linear(num_channels[-1], output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        # 应用TCN和线性层,然后使用ReLU激活函数
        y = self.tcn(x)  # [N,C_out,L_out=L_in]
        return self.relu(self.linear(y[:, :, -1]))

 2、train.py 训练通用模板

3、Config.py 参数定义

config类中定义了项目所有需要的参数,可以在里面修改训练参数。

 4、DataSplit.py 数据划分

 DataSplit.py 是实现数据划分的函数,通过滑动窗口,将每个窗口大小的数据作为训练数据,将其后面一个数据作为预测结果,再进行划分训练数据和标签,最后分成训练集和验证集。

 5、test_stock_TCN_run.py 训练文件

该py文件实现整体训练流程并做绘图操作。依次实现加载数据、数据标准化、取出WIND数据、划分训练集测试集、数据转化为Tensor、形成数据更迭器、载入模型、定义损失、定义优化器、开始训练、损失可视化、显示预测结果。

6、test_pth.py 模型训练后的测试文件

采用模型训练完成后的pth对数据进行预测,可以展示模型预测效果,前面的处理过程类似test_wind_CNN.py所示。

7、loss_draw.py 模型训练后的loss绘图

将训练后产生并收集的loss.csv展示出来,也就是损失图,红框可调展示范围。 

 

 六、实验结果及分析

1、loss损失图

该损失是训练了200个epoch的损失图:

 纵坐标放大局部范围展示:

 2、预测效果展示

训练epoch=200后的股票收盘价预测效果如下(使用pth参数文件进行测试预测):

 局部展示(展示前两百天的预测效果):

七、总结及资源

若有朋友需要可运行的源码和数据集,可以guan注【科研小条】公众号,回复【股票预测TCN】,即可获得。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

基于TCN时间卷积网络(含因果膨胀卷积)的单特征输入股票预测项目实战(pytorch)(一维特征)【有数据集和代码,可运行】 的相关文章

随机推荐

  • VM VirtualBox 全屏模式 && 自动缩放模式 相互切换

    1 自动缩放模式 热键Host C 偶然一次机会 把虚拟机切换为了自动缩放模式 如下图 想要再切换为全屏模式 发现不知如何操作 后来折腾了一会儿 切换成功 以此备录一下 2 切换为全屏模式 热键Host F 切换为全屏模式的快捷键为Host
  • liveshare开发体验 vs_imgcook体验

    D2今年收费了 我所在创业公司没有报销 当然门票也不是什么大钱 无奈忙成狗错过了早鸟票 指望后面看看分享ppt 无意中看到D2官方流出的一个感兴趣的网址 说是 可以由视觉稿一键生成代码 https imgcook taobao org 创业
  • Kafka、RabbitMQ、RocketMQ 消息中间件的对比

    什么是消息队列 消息队列是在消息的传输过程中保存消息的容器 包含以下 3 元素 Producer 消息生产者 负责产生和发送消息到 Broker Broker 消息处理中心 负责消息存储 确认 重试等 一般其中会包含多个 Queue Con
  • 3DCAT实时云渲染助力VR虚拟现实迈向成熟

    近年来 虚拟现实 Virtual Reality VR 技术在市场上的应用越来越广泛 虚拟现实已成为一个热门的科技话题 相关数据显示 2019年至2021年 我国虚拟现实市场规模不断扩大 从2019年的282 8亿元增长至2021年的583
  • qt Model_View_Delegate 模型_视图_代理

    QT当中model view delegate 模型 视图 代理 此结构实现数据和界面的分离 Qt的模型 视图结构分为三部分 模型 model 视图 view 代理 Delegate 其中模型与数据源通信 并为其它部件提供接口 视图从模型中
  • CSS动画-Animation

    一 动画介绍 动画 animation 是CSS3中具有颠覆性的特征之 可通过设置多个节点来精确控制一个或一组动画常用来实现复杂的动画效果 相比较过渡 动画可以实现更多变化 更多控制的效果 二 动画组成 制作动画分为两个部分 keyfram
  • 立体匹配中的NCC,SAD,SSD算法

    常用的基于区域的局部匹配准则主要有图像序列中对应像素差的绝对值 SAD Sum of Absolute Differences 图像序列中对应像素差的平方和 SSD Sum of Squared Differences 图像的相关性 NCC
  • HSqlDB(java内置数据库)

    1 HSqlDB简介 HSQLDB是一款Java内置的数据库 非常适合在用于快速的测试和演示的Java程序中 无需独立安装数据库 HSQLDB有三种模式 1 Server 就像Mysql那样 2 In Process 又叫做 Standal
  • OpenCV颜色查找表

    Mat color imread flover jpeg Mat lut Mat zeros 256 1 CV 8UC3 for int i 0 i lt 256 i lut at
  • 《每日一题》NO.18:哪些因素会影响标准单元的延迟?

    芯司机 每日一题 会每天更新一道IC面试笔试题 其中有些题目已经被很多企业参考采用了哦 聪明的你快来挑战一下吧 今天是第18题 标准单元是RTL2GDS流程的基础 哪些因素会影响到标准单元的延迟呢 我们在工程项目中应该如何处理这些因素呢 快
  • springboot2

    springboot2 springboot2 核心功能 配置文件 web开发 数据访问 Junit5测试 actutor生产指标监控 springboot核心原理解析 springboot2场景整合 虚拟化技术 安全控制 缓存技术 消息中
  • 什么是SQL注入式攻击,如何去防范SQL注入式攻击

    一 SQL注入式攻击 1 所谓SQL注入式攻击 就是攻击者把SQL命令插入到Web表单的输入域或页面请求的查询字符串 欺骗服务器执行恶意的SQL命令 2 在某些表单中 用户输入的内容直接用来构造 或者影响 动态SQL命令 或作为存储过程的输
  • 测试用例、缺陷报告示例子

    测试用例 用例标题的作用 让人更清晰直观的查看 前置条件和测试步骤 测试步骤是在前置条件的基础上进行的 合格测试用例标题 缺陷 缺陷的介绍 需求 规格 说明书中明确要求的功能 缺失 少功能 需求 规格 说明书中致命不应该出现的错误 功能错误
  • 【项目实战】C/C++语言带你实现:围棋游戏丨详细逻辑+核心源码

    每天一个编程小项目 提升你的编程能力 游戏介绍 下围棋的程序 实现了界面切换 选择路数 和围棋规则 也实现了点目功能 不过只有当所有死子都被提走才能点目 不然不准确 操作方法 鼠标操作 游戏截图 编译环境 VisualStudio2019
  • 看完这篇 教你玩转渗透测试靶机Vulnhub——The Planets:Mercury

    Vulnhub靶机The Planets Mercury渗透测试详解 Vulnhub靶机介绍 Vulnhub靶机下载 Vulnhub靶机安装 Vulnhub靶机漏洞详解 信息收集 漏洞发现 SSH登入 CVE 2021 4034漏洞提权 获
  • CCAI 2017

    阅读原文请点击 摘要 2017 中国人工智能大会 CCAI 2017 在杭州国际会议中心盛大召开 CCAI发起人 中国科学院院士 中国人工智能学会副理事长谭铁牛院士在大会首日主会场进行了现场致辞 7月22日 23日的 2017 中国人工智能
  • 你必须收藏的Github技巧

    一秒钟把Github项目变成前端网站 GitHub Pages大家可能都知道 常用的做法 是建立一个gh pages的分支 通过setting里的设置的GitHub Pages模块可以自动创建该项目的网站 这里经常遇到的痛点是 master
  • 详细解析赋值、浅拷贝和深拷贝的区别

    一 赋值 Copy 赋值是将某一数值或对象赋给某个变量的过程 分为下面 2 部分 基本数据类型 赋值 赋值之后两个变量互不影响 引用数据类型 赋址 两个变量具有相同的引用 指向同一个对象 相互之间有影响 对基本类型进行赋值操作 两个变量互不
  • 将gitlab的代码仓库实时备份到其他服务器

    首先 这个题目是不完全正确的 因为经过各种尝试 gitlab的仓库直接备份到远端 拷贝回来后是不能使用的 表现为gitlab中能看到项目 但每个项目的内容都无法读取出来 页面上会有报错提示 所以 最终采用的是实时备份gitlab的备份库 最
  • 基于TCN时间卷积网络(含因果膨胀卷积)的单特征输入股票预测项目实战(pytorch)(一维特征)【有数据集和代码,可运行】

    一 项目简介 股票预测是金融领域中的重要问题 通过对历史股票数据的分析和建模 我们可以尝试预测未来股票的价格趋势 为投资决策提供参考 本项目是基于PyTorch深度学习框架实现一个使用时间卷积网络 TCN Temporal Convolut