transformer序列预测示例

2023-11-12

本文参考:

【python量化】将Transformer模型用于股票价格预测_蚂蚁爱Python的博客-CSDN博客_transformer 股票

一、Transformer初步架构图

二、transformer各组件信息

1、positional_encoder

X -> [positional_encoding(embedding + position)] -> X‘

2、multi_head_attention_forward

输入:

Query:(L, N, E)

Key:(S, N, E)

Value:(S, N, E)

输出:

Attn_output:(L,N,E)

Attn_output_weights:(N,L,S)

3、MultiheadAttention

输入:

Embed_dim

Num_heads

Multihead_attn = MultiheadAttention(embed_dim, num_heads)

Attn_output, attn_output_weights = multihead_attn(query, key, value)

4、TransformerEncoderLayer

输入:

d_model,

n_head

5、TransformerEncoder

输入:

Encoder_layer

Num_layers

Encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)

Transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)

6、TransformerDecoderLayer

输入:

d_model

n_head

7、TransformerDecoder

输入:

Decoder_layer

Num_layers

Decoder_layer = TransformerDecoder(d_model=512, nhead=8)

Transformer_decoder = TransformerDecoder(decoder_layer, num_layers=6)

8、transformer

输入:

Nhead

Num_encoder_layers

Transformer_model = Transformer(nhead=16, num_encoder_layers=12)

Src = torch.rand((10, 32, 512))

Tgt = torch.rand((20, 32, 512))

Out = transformer_model(src, tgt)

三、transformer进行序列预测示例

1、csv输入序列的shape为(126, 1)

2、切分训练集(88, )和测试集(38, )

3、滑动窗口input_window得到(序列值, 标签值)

得到的训练序列为(67, 2, 20)

67:窗口滑动次数;

2:序列+标签=2个

20:输入窗口长度

得到测试序列(17, 2, 20)

17:窗口滑动次数

2:序列+标签=2个

20:输入窗口长度

4、将input和target转成模型需要的格式(S=20,N=2, E=1)

5、模型训练至掩码生成(20, 20)

6、Pos_encoder

Src(20, 2, 1)  + positional_encoder(20, 1, 250) => (20, 2, 250)

广播机制完成

7、transformer_encoder:(20, 2, 250)-> (20, 2, 250)

8、decoder:(20, 2, 250) -> (20, 2, 1) 通过Linear完成

9、loss值计算

因为有mask存在,所以每次预测下个值时的后续值是未知的。计算loss的过程如下:

 四、完整代码(针对原文中错误修改过的)

import torch
import torch.nn as nn
import numpy as np
import time
import math
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
import pandas as pd

torch.manual_seed(0)
np.random.seed(0)

input_window = 20
output_window = 1
batch_size = 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:x.size(0), :]

class TransAm(nn.Module):
    def __init__(self, feature_size=250, num_layers=1, dropout=0.1):
        super(TransAm, self).__init__()
        self.model_type = 'Transformer'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(feature_size)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=feature_size, nhead=10, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.decoder = nn.Linear(feature_size, 1)
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0,1)
        mask = mask.float().masked_fill(mask==0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask


    def forward(self, src):
        if self.src_mask is None or self.src_mask.shape[0] == len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask
            src = self.pos_encoder(src)
            output = self.transformer_encoder(src, self.src_mask)
            output = self.decoder(output)
            return output

def create_inout_sequences(input_data, tw):
    inout_seq = []
    L = len(input_data)
    for i in range(L - tw):
        train_seq = input_data[i: i + tw]
        train_label = input_data[i + output_window: i + tw + output_window]
        inout_seq.append((train_seq, train_label))
    return torch.FloatTensor(inout_seq)


def get_data():
    series = pd.read_csv("D:\\temp\\0001_daily.csv", usecols=[0])
    scaler = MinMaxScaler(feature_range=[-1, 1])
    series = scaler.fit_transform(series.values.reshape(-1,1)).reshape(-1)
    train_samples = int(0.7 * len(series))
    train_data = series[:train_samples]
    test_data = series[train_samples:]
    train_sequence = create_inout_sequences(train_data, input_window)
    train_sequence = train_sequence[:-output_window]
    test_data = create_inout_sequences(test_data, input_window)
    test_data = test_data[:-output_window]
    return train_sequence.to(device), test_data.to(device)

def get_batch(source, i, batch_size):
    seq_len = min(batch_size, len(source) - 1 - i)
    data = source[i : i + seq_len]
    input = torch.stack(torch.stack([item[0] for item in data]).chunk(input_window, 1))
    target = torch.stack(torch.stack([item[1] for item in data]).chunk(input_window, 1))
    return input, target


def train(train_data):
    model.train()
    total_loss = 0

    for batch_index, i in enumerate(range(0, len(train_data) - 1, batch_size)):
        start_time = time.time()
        data, targets = get_batch(train_data, i, batch_size)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm(model.parameters(), 0.7)
        optimizer.step()
        total_loss += loss.item()

        log_interval = int(len(train_data) / batch_size / 5)
        if batch_index % log_interval == 0 and batch_index > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.6f} | {:5.2f} ms | loss {:5.5f} | ppl {:8.2f}'
                  .format(epoch, batch_index, len(train_data) // batch_size, scheduler.get_lr()[0],
                          elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)))

def evaluate(eval_model, data_source):
    eval_model.eval()
    total_loss = 0
    eval_batch_size = 1000

    with torch.no_grad():
        for i in range(0, len(data_source) - 1, eval_batch_size):
            data, targets = get_batch(data_source, i, eval_batch_size)
            output = eval_model(data)
            total_loss += len(data[0]) * criterion(output, targets).cpu().item()
        return total_loss / len(data_source)

def plot_and_loss(eval_model, data_source, epoch):
    eval_model.eval()
    total_loss = 0.
    test_result = torch.Tensor(0)
    truth = torch.Tensor(0)
    with torch.no_grad():
        for i in range(0, len(data_source) - 1):
            data, target = get_batch(data_source, i, 1)
            output = eval_model(data)
            total_loss += criterion(output, target).item()
            test_result = torch.cat((test_result, output[-1].view(-1).cpu()), 0)
            truth = torch.cat((truth, target[-1].view(-1).cpu()), 0)

        plt.plot(test_result, color='red')
        plt.plot(truth, color='blue')
        plt.grid(True, which='both')
        plt.axhline(y=0, color='k')
        plt.savefig('transformer-epoch%d.png' % epoch)
        plt.close()

    return total_loss / i



train_data, val_data = get_data()
model = TransAm().to(device)
criterion = nn.MSELoss()
lr = 0.005
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.95)
epochs = 60

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(train_data)

    if(epoch % 10 is 0):
        val_loss = plot_and_loss(model, val_data, epoch)
    else:
        val_loss = evaluate(model, val_data)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.5f} | valid ppl {:8.2f}'.format(epoch, (
                time.time() - epoch_start_time), val_loss, math.exp(val_loss)))
    print('-' * 89)
    scheduler.step()

模型使用示例:

保存模型:
torch.save(model.state_dict(), 'transformer.pth.tar')


使用模型:
import os

model_path = os.path.join(os.getcwd(), 'transformer.pth.tar')
model_dict = torch.load(model_path)
model.load_state_dict(model_dict)

test_data = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0])
test_data = torch.from_numpy(test_data).view(-1, 1, 1).to('cuda').float()  # seq_len, batch_size, embedding
test_output = model(test_data)
print(test_output[-1].item())
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

transformer序列预测示例 的相关文章

随机推荐

  • java static变量_什么是静态(static)?什么是静态方法,静态变量,静态块和静态类...

    static是Java中的一个关键字 我们不能声明普通外层类或者包为静态的 static用于下面四种情况 静态变量 我们可以将类级别的变量声明为static 静态变量是属于类的 而不是属于类创建的对象或实例 因为静态变量被类的所有实例共用
  • Excel解决CSV文件中的乱码

    背景 对于CSV文件中的乱码问题 大概率是编码的问题 可以基于Excel进行编码转换 或将文本进行编码转化 解决办法 打开Excel应用 点击文件 新建文件 点击文件 选择导入 导入具体的CSV文件 选择CSV文件 点击完成即可 然后就可以
  • 麻雀算法(SSA)优化长短期记忆神经网络的数据分类预测,SSA-LSTM分类预测,多输入单输出模型

    清空环境变量 warning off 关闭报警信息 close all 关闭开启的图窗 clear 清空变量 clc 清空命令行 读取数据 res xlsread 数据集 xlsx num res size res 1 样本数 每一行 是一
  • vue + css气泡图动态气泡图

    div ul class bubbleUl li class bubbleLi div class textBubble span item value span div div class topDiv div style width 1
  • 从一个类调用另一个类的方法或属性

    package 练习 class yu String m 人工小智能 public void shout1 System out println 我是 m 今年18岁 同类中直接调用了m public void shout2 yu p ne
  • grep指令详解

    shell grep指令详解 grep 参数 e 使用PATTERN作为模式 这可以用于指定多个搜索模式 或保护以连字符 开头的图案 指定字符串做为查找文件内容的样式 f 指定规则文件 其内容含有一个或多个规则样式 让grep查找符合规则条
  • 【SQL注入】堆叠注入

    目录 一 简介 概述 原理 优势 前提 防护 二 分析堆叠注入 使用MYSQL 第一步 使用堆叠查询构造多条语句 第二步 查看语句是否成功执行 第三步 删除test 再查询 第四步 执行其它查询语句 一 简介 概述 顾名思义 就是多条语句堆
  • 【电脑使用】chm文件打开显示确保Web地址 //ieframe.dll/dnserrordiagoff.htm#正确

    问题描述 最近找到一个之前的一个chm文件 打开的时候内容是空白的 同时报错 确保Web地址 ieframe dll dnserrordiagoff htm 正确 如下图所示 参考链接 解决方案 根据文章中提示的方法 找到了原因所在 chm
  • 详解微信小程序支付流程

    小程序微信支付图 微信小程序的商户系统一般是以接口的形式开发的 小程序通过调用与后端约定好的接口进行参数的传递以及数据的接收 在小程序支付这块 还需要跟微信服务器进行交互 过程大致是这样的 一 小程序调用登录接口获取code 传递给商户服务
  • linux传输文件指令

    使用scp传输 从本地传到服务器 scp P 目的端口 本地路径 目的用户名 目的IP 目的路径 r参数可用来传文件夹 scp r P 使用sftp传输 sftp oPort 目的端口号 目的用户名 目的IP get下载 put上传
  • 这5个开源和免费静态代码分析工具,你一个都没有用过吗?不会吧

    如果您是软件开发人员或代码安全分析师 则通常需要分析源代码以检测安全漏洞并维护安全的质量代码 但是您的代码中可能存在许多难以手动发现的问题 毕竟 我们仍然是人类 因此即使是最高级的安全分析师也都会错过一些安全漏洞 我们提供了源代码分析工具功
  • MySql中4种批量更新的方法

    MySql中4种批量更新的方法 mysql 批量更新共有以下四种办法1 replace into 批量更新 replace into test tbl id dr values 1 2 2 3 x y 例子 replace into boo
  • 人生顿悟之宽以待人,严以律己

    台风已经过去了 天气也渐渐地晴朗了 但是不知道为什么自己的心情却越发觉得沉重起来 总觉得生活中少了点什么 是没有了以往的激情 还是多了几分压力 看了近1个月的房子 两个人的所有积蓄加上两家人的积蓄 勉强可以付得起首付 接下去就是了无止境的房
  • 开启MySQL主从半同步复制

    记录配置mysql主从半同步复制的过程 加载lib 所有主从节点都要配置 主库 install plugin rpl semi sync master soname semisync master so 从库 install plugin
  • Android Studio从一个activity到另一个activity

    Android Studio从一个activity跳转到另一个activity 简单的跳转 创建两个activity 创建跳转按钮 在第一个activity的onCreate中添加按钮监听事件 编写内部类 button setOnClick
  • [网盘工具/百度网盘]秒传链接的使用 -2022版油猴网页脚本

    注 此项技术仅针对百度网盘有效 软件要求 Chrome或Firefox等支持tampermonkey Violentmonkey的浏览器 1 什么是秒传链接 度盘秒传链接 标准提取码 由128位 32个16进制数 128位 32个16进制数
  • Neon intrinsics

    1 介绍 在上篇中 介绍了ARM的Neon 本篇主要介绍Neon intrinsics的函数用法 也就是assembly之前的用法 NEON指令是从Armv7架构开始引入的SIMD指令 其共有16个128位寄存器 发展到最新的Arm64架构
  • mysql设置wait timeout_mysql修改wait_timeout_MySQL

    bitsCN com mysql修改wait timeout mysql mysql gt show global variables like wait timeout 其默认值为8小时 mysql的一个connection空闲时间超过8
  • TypeScript中类的继承

    特点 避免重复创建类 减少代码数量 通过extends关键字继承父类 通过super继承父类的属性和方法 实例 class Person1 定义属性 name string age number gender string construc
  • transformer序列预测示例

    本文参考 python量化 将Transformer模型用于股票价格预测 蚂蚁爱Python的博客 CSDN博客 transformer 股票 一 Transformer初步架构图 二 transformer各组件信息 1 position