【RNN从入门到实战】GRU入门到实战——使用GRU预测股票。

2023-10-27

摘要

      GRU是LSTM网络的一种效果很好的变体,它较LSTM网络的结构更加简单,而且效果也很好,因此也是当前非常流形的一种网络。GRU既然是LSTM的变体,因此也是可以解决RNN网络中的长依赖问题。

       在LSTM中引入了三个门函数:输入门、遗忘门和输出门来控制输入值、记忆值和输出值。而在GRU模型中只有两个门:分别是更新门和重置门。具体结构如下图所示:
在这里插入图片描述
图中的update gate和reset gate分别表示更新门和重置门。
     更新门的作用类似于LSTM的遗忘和输入门。它决定前一时刻的状态信息哪些被丢弃和哪些被更新新信息,更新门的值越大说明前一时刻的状态信息带入越多。
    重置门控制前一状态有多少信息被写入到当前的候选集 h t ~ \tilde{h_{t}} ht~上,重置门越小,前一状态的信息被写入的越少。
详见下图:
在这里插入图片描述

GRU和LSTM的区别

在这里插入图片描述

GRU使用门控机制学习长期依赖关系的基本思想和 LSTM 一致,但还是有一些关键区别:

1、GRU 有两个门(重置门与更新门),而 LSTM 有三个门(输入门、遗忘门和输出门)。
2、GRU 并不会控制并保留内部记忆(c_t),且没有 LSTM 中的输出门。
3、LSTM 中的输入与遗忘门对应于 GRU 的更新门,重置门直接作用于前面的隐藏状态。
4、在计算输出时并不应用二阶非线性。

如何选择?

由于 GRU 参数更少,收敛速度更快,因此其实际花费时间要少很多,这可以大大加速了我们的迭代过程。 而从表现上讲,二者之间孰优孰劣并没有定论,这要依据具体的任务和数据集而定,而实际上,二者之间的 performance 差距往往并不大,远没有调参所带来的效果明显,我通常的做法,首先选择GRU作为基本的单元,因为其收敛速度快,可以加速试验进程,快速迭代,而我认为快速迭代这一特点很重要。如果实现没其余优化技巧,才会尝试将 GRU 换为 LSTM,看看有没有变化。

实战——使用GRU预测股票

制作数据集

数据的下载地址:http://quotes.money.163.com/trade/lsjysj_zhishu_000001.html
首先选择查询数据的范围,然后选择下载数据:
在这里插入图片描述
点击下载数据后,弹出个框,这个框里可以看到起止日期,数据的列,默认勾选全部,然后单击下载。
在这里插入图片描述
将下载后的数据放到工程的根目录下面,然后做一些处理。
第一步 读入csv文件,编码方式设置为gbk。
第二步 将日期列转为“年-月-日”这样的格式。
第三步 根据日期排序,数据默认是从降序,改为升序。
最后 保存到“sh.csv”中。
代码如下:

import pandas as pd

df = pd.read_csv("000001.csv", encoding='gbk')
df["日期"] = pd.to_datetime(df["日期"], format="%Y-%m-%d")
df.sort_values(by=["日期"], ascending=True)
df = df.sort_values(by=["日期"], ascending=True)
df.to_csv("sh.csv", index=False, sep=',')

到这里数据集制作完成了

读入数据集,构建时序数据。

generate_df_affect_by_n_days函数,通过一个序列来生成一个矩阵(用于处理时序的数据)。就是把当天的前n天作为参数,当天的数据作为label。
readData中的文件名为:sh.csv ,参数column,代表选用的数据,本次预测只用了一列数据,列名就是column,参数n就是之前模型中所说的n,代表column前n天的数据。train_end表示的是后面多少个数据作为测试集。

import pandas as pd
import matplotlib.pyplot as plt
import datetime
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader


def generate_df_affect_by_n_days(series, n, index=False):
    if len(series) <= n:
        raise Exception("The Length of series is %d, while affect by (n=%d)." % (len(series), n))
    df = pd.DataFrame()
    for i in range(n):
        df['c%d' % i] = series.tolist()[i:-(n - i)]
    df['y'] = series.tolist()[n:]
    if index:
        df.index = series.index[n:]
    return df


def readData(column='收盘价', n=30, all_too=True, index=False, train_end=-300):
    df = pd.read_csv("sh.csv", index_col=0, encoding='utf-8')
    df.fillna(0, inplace=True)
    df.replace(to_replace="None", value=0)
    del df["股票代码"]
    del df["名称"]
    df.index = list(map(lambda x: datetime.datetime.strptime(x, "%Y-%m-%d").date(), df.index))
    df_column = df[column].copy()
    df_column_train, df_column_test = df_column[:train_end], df_column[train_end - n:]
    df_generate_from_df_column_train = generate_df_affect_by_n_days(df_column_train, n, index=index)
    print(df_generate_from_df_column_train)
    if all_too:
        return df_generate_from_df_column_train, df_column, df.index.tolist()
    return df_generate_from_df_column_train

构建模型

模型见RNN类,采用GRU+全连接。隐藏层设置64,GRU的输出是64维的,所以设置全连接也是64。
TrainSet是数据读取类,将输出数据的最后一列做标签,前面的列做输入,类的写法是Pytorch的固定模式。

class RNN(nn.Module):
    def __init__(self, input_size):
        super(RNN, self).__init__()
        self.rnn = nn.GRU(
            input_size=input_size,
            hidden_size=64,
            num_layers=1,
            batch_first=True,
        )
        self.out = nn.Sequential(
            nn.Linear(64, 1),
        )
        self.hidden = None
    def forward(self, x):
        r_out, self.hidden = self.rnn(x)  # None 表示 hidden state 会用全0的 state
        out = self.out(r_out)
        return out


class TrainSet(Dataset):
    def __init__(self, data):
        # 定义好 image 的路径
        self.data, self.label = data[:, :-1].float(), data[:, -1].float()

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return len(self.data)

训练与测试

n为模型中的n
LR是模型的学习率
EPOCH是多次循环
train_end这个在之前的数据集中有提到。(注意是负数)
n = 30
LR = 0.0001
EPOCH = 100
train_end = -500
loss选用mse
预测的数据选择“收盘价”

n = 10
LR = 0.0001
EPOCH = 100
train_end = -500
# 数据集建立
df, df_all, df_index = readData('收盘价', n=n, train_end=train_end)
df_all = np.array(df_all.tolist())
plt.plot(df_index, df_all, label='real-data')
df_numpy = np.array(df)
df_numpy_mean = np.mean(df_numpy)
df_numpy_std = np.std(df_numpy)
df_numpy = (df_numpy - df_numpy_mean) / df_numpy_std
df_tensor = torch.Tensor(df_numpy)
trainset = TrainSet(df_tensor)
trainloader = DataLoader(trainset, batch_size=16, shuffle=True)
rnn = RNN(n)
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)  # optimize all cnn parameters
loss_func = nn.MSELoss()
for step in range(EPOCH):
    for tx, ty in trainloader:
        output = rnn(torch.unsqueeze(tx, dim=0))
        loss = loss_func(torch.squeeze(output), ty)
        optimizer.zero_grad()  # clear gradients for this training step
        loss.backward()  # back propagation, compute gradients
        optimizer.step()
    print(step, loss)
    if step % 10:
        torch.save(rnn, 'rnn.pkl')
torch.save(rnn, 'rnn.pkl')
generate_data_train = []
generate_data_test = []
test_index = len(df_all) + train_end
df_all_normal = (df_all - df_numpy_mean) / df_numpy_std
df_all_normal_tensor = torch.Tensor(df_all_normal)
for i in range(n, len(df_all)):
    x = df_all_normal_tensor[i - n:i]
    x = torch.unsqueeze(torch.unsqueeze(x, dim=0), dim=0)
    y = rnn(x)
    if i < test_index:
        generate_data_train.append(torch.squeeze(y).detach().numpy() * df_numpy_std + df_numpy_mean)
    else:
        generate_data_test.append(torch.squeeze(y).detach().numpy() * df_numpy_std + df_numpy_mean)
plt.plot(df_index[n:train_end], generate_data_train, label='generate_train')
plt.plot(df_index[train_end:], generate_data_test, label='generate_test')
plt.legend()
plt.show()
plt.cla()
plt.plot(df_index[train_end:-400], df_all[train_end:-400], label='real-data')
plt.plot(df_index[train_end:-400], generate_data_test[:-400], label='generate_test')
plt.legend()
plt.show()

在这里插入图片描述
在这里插入图片描述

总结

本实例数据集使用上证的收盘价,构建时序数据集,使用GRU对数据预测,从结果上来看,走势有一定的滞后性,由于只是用了价格预测价格,还是太简单了,可以考虑加入更多的特征去优化这一算法。

完整代码

import pandas as pd
import matplotlib.pyplot as plt
import datetime
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

def generate_df_affect_by_n_days(series, n, index=False):
    if len(series) <= n:
        raise Exception("The Length of series is %d, while affect by (n=%d)." % (len(series), n))
    df = pd.DataFrame()
    for i in range(n):
        df['c%d' % i] = series.tolist()[i:-(n - i)]
    df['y'] = series.tolist()[n:]
    if index:
        df.index = series.index[n:]
    return df


def readData(column='收盘价', n=30, all_too=True, index=False, train_end=-300):
    df = pd.read_csv("sh.csv", index_col=0, encoding='utf-8')
    df.fillna(0, inplace=True)
    df.replace(to_replace="None", value=0)
    del df["股票代码"]
    del df["名称"]
    df.index = list(map(lambda x: datetime.datetime.strptime(x, "%Y-%m-%d").date(), df.index))
    df_column = df[column].copy()
    df_column_train, df_column_test = df_column[:train_end], df_column[train_end - n:]
    df_generate_from_df_column_train = generate_df_affect_by_n_days(df_column_train, n, index=index)
    print(df_generate_from_df_column_train)
    if all_too:
        return df_generate_from_df_column_train, df_column, df.index.tolist()
    return df_generate_from_df_column_train


class RNN(nn.Module):
    def __init__(self, input_size):
        super(RNN, self).__init__()
        self.rnn = nn.GRU(
            input_size=input_size,
            hidden_size=64,
            num_layers=1,
            batch_first=True,
        )
        self.out = nn.Sequential(
            nn.Linear(64, 1),
        )
        self.hidden = None
    def forward(self, x):
        r_out, self.hidden = self.rnn(x)  # None 表示 hidden state 会用全0的 state
        out = self.out(r_out)
        return out


class TrainSet(Dataset):
    def __init__(self, data):
        # 定义好 image 的路径
        self.data, self.label = data[:, :-1].float(), data[:, -1].float()

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return len(self.data)


n = 10
LR = 0.0001
EPOCH = 100
train_end = -500
# 数据集建立
df, df_all, df_index = readData('收盘价', n=n, train_end=train_end)
df_all = np.array(df_all.tolist())
plt.plot(df_index, df_all, label='real-data')
df_numpy = np.array(df)
df_numpy_mean = np.mean(df_numpy)
df_numpy_std = np.std(df_numpy)
df_numpy = (df_numpy - df_numpy_mean) / df_numpy_std
df_tensor = torch.Tensor(df_numpy)
trainset = TrainSet(df_tensor)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
rnn = RNN(n)
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)  # optimize all cnn parameters
loss_func = nn.MSELoss()
for step in range(EPOCH):
    for tx, ty in trainloader:
        output = rnn(torch.unsqueeze(tx, dim=0))
        loss = loss_func(torch.squeeze(output), ty)
        optimizer.zero_grad()  # clear gradients for this training step
        loss.backward()  # back propagation, compute gradients
        optimizer.step()
    print(step, loss)
    if step % 10:
        torch.save(rnn, 'rnn.pkl')
torch.save(rnn, 'rnn.pkl')
generate_data_train = []
generate_data_test = []
test_index = len(df_all) + train_end
df_all_normal = (df_all - df_numpy_mean) / df_numpy_std
df_all_normal_tensor = torch.Tensor(df_all_normal)
for i in range(n, len(df_all)):
    x = df_all_normal_tensor[i - n:i]
    x = torch.unsqueeze(torch.unsqueeze(x, dim=0), dim=0)
    y = rnn(x)
    if i < test_index:
        generate_data_train.append(torch.squeeze(y).detach().numpy() * df_numpy_std + df_numpy_mean)
    else:
        generate_data_test.append(torch.squeeze(y).detach().numpy() * df_numpy_std + df_numpy_mean)
plt.plot(df_index[n:train_end], generate_data_train, label='generate_train')
plt.plot(df_index[train_end:], generate_data_test, label='generate_test')
plt.legend()
plt.show()
plt.cla()
plt.plot(df_index[train_end:-400], df_all[train_end:-400], label='real-data')
plt.plot(df_index[train_end:-400], generate_data_test[:-400], label='generate_test')
plt.legend()
plt.show()
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【RNN从入门到实战】GRU入门到实战——使用GRU预测股票。 的相关文章

  • 如何在keras中使用Bert作为长文本分类中的段落编码器来实现网络?

    我正在做一个长文本分类任务 文档中有超过 10000 个单词 我计划使用 Bert 作为段落编码器 然后将段落的嵌入逐步输入 BiLSTM 网络如下 输入 batch size max paragraph len max tokens pe
  • 如何使用CNN来训练不同大小的输入数据?

    CNN 似乎主要针对固定大小的输入来实现 现在我想用CNN来训练一些不同大小的句子 有哪些常用的方法 以下建议主要与用于计算机视觉任务 特别是识别 的 CNN 相关 但也可能对您的领域有所帮助 我会看看He 等人的 用于视觉识别的深度卷积网
  • NLTK 中的 FreqDist 未对输出进行排序

    我是 Python 新手 我正在尝试自学语言处理 python 中的 NLTK 有一个名为 FreqDist 的函数 可以给出文本中单词的频率 但由于某种原因它无法正常工作 这是教程让我写的 fdist1 FreqDist text1 vo
  • 日语/字符的编程技巧[关闭]

    Closed 这个问题需要多问focused help closed questions 目前不接受答案 我有一个想法 可以编写一些网络应用程序来帮助我 也许还有其他人 更好地学习日语 因为我正在学习日语 我的问题是该网站主要是英文的 所以
  • 如何在 python-gensim 中使用潜在狄利克雷分配(LDA)来抽象二元组主题而不是一元组?

    LDA 原始输出 一元语法 主题1 水肺 水 蒸汽 潜水 主题2 二氧化物 植物 绿色 碳 所需输出 二元组主题 主题1 水肺潜水 水蒸气 主题2 绿色植物 二氧化碳 任何想法 鉴于我有一个名为docs 包含文档中的单词列表 我可以使用 n
  • 如何将地名词典或词典表示为 crf++ 中的特征?

    如何使用地名词典或词典作为功能CRF https taku910 github io crfpp 详细说明 假设我想对人名进行 NER 并且我有一个包含常见人名的地名词典 或字典 我想使用这个地名词典作为 crf 的输入 我该怎么做 我正在
  • 下载变压器模型以供离线使用

    我有一个训练有素的 Transformer NER 模型 我想在未连接到互联网的机器上使用它 加载此类模型时 当前会将缓存文件下载到 cache 文件夹 要离线加载并运行模型 需要将 cache 文件夹中的文件复制到离线机器上 然而 这些文
  • 如何调整 NLTK 句子标记器

    我正在使用 NLTK 来分析一些经典文本 但我在按句子标记文本时遇到了麻烦 例如 这是我从以下内容中得到的片段莫比迪克 http www gutenberg org cache epub 2701 pg2701 txt import nlt
  • BERT - 池化输出与序列输出的第一个向量不同

    我在 Tensorflow 中使用 BERT 有一个细节我不太明白 根据文档 https tfhub dev google bert uncased L 12 H 768 A 12 1 https tfhub dev google bert
  • 使用我自己的训练示例训练 spaCy 现有的 POS 标记器

    我正在尝试在我自己的词典上训练现有的词性标注器 而不是从头开始 我不想创建一个 空模型 在spaCy的文档中 它说 加载您想要统计的模型 下一步是 使用add label方法将标签映射添加到标记器 但是 当我尝试加载英文小模型并添加标签图时
  • 将单引号替换为双引号并排除某些元素

    我想用双引号替换字符串中的所有单引号 但出现的情况除外 例如 n t ll m 等 input the stackoverflow don t said hey what output the stackoverflow don t sai
  • 将 python NLTK 解析树保存到图像文件[重复]

    这个问题在这里已经有答案了 这可能会复制这个 stackoverflowquestion https stackoverflow com questions 23429117 saving nltk drawn parse tree to
  • 语音识别中如何处理同音词?

    对于那些不熟悉什么是同音字 https en wikipedia org wiki Homophone是的 我提供以下示例 我们的 是 嗨和高 到 太 二 在使用时语音API https developer apple com docume
  • 如何确保用户只提交英文文本

    我正在构建一个涉及自然语言处理的项目 由于nlp模块目前只处理英文文本 所以我必须确保用户提交的内容 不长 只有几个单词 是英文的 是否有既定的方法来实现这一目标 首选 Python 或 Javascript 方式 如果内容足够长我会推荐一
  • NLTK 中的 wordnet lemmatizer 不适用于副词 [重复]

    这个问题在这里已经有答案了 from nltk stem import WordNetLemmatizer x WordNetLemmatizer x lemmatize angrily pos r Out 41 angrily 这是 nl
  • SpaCy 中的自定义句子边界检测

    我正在尝试在 spaCy 中编写一个自定义句子分段器 它将整个文档作为单个句子返回 我编写了一个自定义管道组件 它使用以下代码来执行此操作here https github com explosion spaCy issues 1850 但
  • 如何在R中使用OpenNLP获取POS标签?

    这是 R 代码 library NLP library openNLP tagPOS lt function x s lt as String x word token annotator lt Maxent Word Token Anno
  • 举例解释bpe(字节对编码)?

    有人可以帮忙解释一下背后的基本概念吗BPE模型 除了这张纸 https arxiv org abs 1508 07909 目前还没有那么多解释 到目前为止我所知道的是 它通过将罕见和未知的单词编码为子词单元序列来实现开放词汇表上的 NMT
  • 斯坦福 CoreNLP:使用部分现有注释

    我们正在尝试利用现有的 代币化 句子分割 和命名实体标记 同时我们希望使用斯坦福 CoreNlp 额外为我们提供 词性标注 词形还原 和解析 目前 我们正在尝试以下方式 1 为 pos lemma parse 创建一个注释器 Propert
  • 管道:多个流消费者

    我编写了一个程序来计算语料库中 NGram 的频率 我已经有一个函数 它消耗一串令牌并生成一个订单的 NGram ngram Monad m gt Int gt Conduit t m t trigrams ngram 3 countFre

随机推荐

  • visio绘制自制图案并填充

    1 首先打开 文件 gt gt 选项 gt gt 高级 勾选以开发人员模式运行 2 选择开发工具 选择线条 3 绘画出椭圆与线条 4 选择操作 使用修剪功能或得相交得两个部分 5 按照上述步骤绘制出想要得图案 全选后选择操作中的连接 6 对
  • 尝试在条件“($(MsBuildMajorVersion) < 16)”中对计算结果为“”而不是数字的“$(MsBuildMajorVersion)”进行数值比较

    https learn microsoft com en us nuget consume packages migrate packages config to package reference 我的处理方法简单粗爆 直接将packag
  • mysql映射表的作用_php – MySQL:了解映射表

    当使用多对多关系时 处理此事的唯一现实的方法是使用映射表 说我们有一所有老师和学生的学校 一个学生可以有多个教师 反之亦然 所以我们做3个表 student id unsigned integer auto increment primar
  • 【RabbitMq】05 RabbitMq 消息确认机制-可靠抵达

    一 保证消息不丢失 1 使用事务消息 性能下降250倍 2 消息确认机制 1 publisher confirmCallback 确认模式 2 publisher returnCallback 未投递到queue退回模式 3 consume
  • BLE连接建立过程详解

    同一款手机 为什么跟某些设备可以连接成功 而跟另外一些设备又连接不成功 同一个设备 为什么跟某些手机可以建立连接 而跟另外一些手机又无法建立连接 同一个手机 同一个设备 为什么他们两者有时候连起来很快 有时候连起来又很慢 Master是什么
  • 【解决问题】mysql 数据库字符串分割之后多行输出方法

    背景 项目需要从一张表查询出来数据插入到另一张表 其中有一个字段是用逗号分隔的字符串 需要多行输入到另一张表 那么这个如何实现呢 方案 下面先粘贴下sql语句 select SUBSTRING INDEX SUBSTRING INDEX v
  • Windows Server 2008 R2 实现多用户同时登陆

    Server 版系统一直都支持多用户同时登陆 这是一个很好用的功能 我们来看看怎么实现的 Start gt Administrator tools gt Remote Desktop Services gt Remote Desktop S
  • Android事件分发详解

    一 事件分发基础知识讲解 1 事件分发的 事件 是指什么 即 用户触摸屏幕时产生的点击事件 Touch事件 这个点击事件 Touch事件 会被封装成MotionEvent对象 从手指接触屏幕至手指离开屏幕这个过程产生的一系列事件 一般情况下
  • 包邮送50本清华出版社高质量Python、爬虫、机器学习书籍!

    来给大家送一波福利 这次联系了 9个好友一起给各位送书 每个号送 5 本 一共 50本 还包邮哦 具体书籍种类 介绍信息文中有详细介绍 这10个公众号 也是在Python AI 算法 数据科学领域非常优秀的公众号 也能帮助大家学到更多有用知
  • 使用 OpenVINO™及实际应用场景数据集加强已训练好的AI模型 - TensorFlow

    作者 Stewart Christie Ragesh Hajela及李翊玮 机器学习要求我们拥有现有数据 并非应用在运行时将使用的数据 而是要从中学习的数据库 实际上 你需要大量的真实数据 越多越好 您提供的示例越多 计算机应能够学习得越好
  • Linux--Qt Creator 创建桌面快捷方式(Ubuntu 亲测可用)

    在 usr share applications 目录下创建qtcreator desktop vim usr share applications qtcreator desktop 将如下内容复制到qtcreator desktop D
  • linux里安装Nginx

    一 下载 参考 https blog csdn net adaizzz article details 126669430 官网 http nginx org 里面展示了所有版本 根据需要下载你所需的版本即可 将下载好的安装包上传至Linu
  • Tomcat一些漏洞的汇总

    前言 Tomcat 服务器是一个免费的开放源代码的Web 应用服务器 属于轻量级应用 服务器 在中小型系统和并发访问用户不是很多的场合下被普遍使用 是开发和调试JSP 程序的首选 目录 一 Tomcat 任意文件写入 CVE 2017 12
  • 球谐函数原理

    参考 球谐函数极其应用 球谐函数是什么 球谐函数是傅里叶级数的高维类比 由一组表示球体表面的基函数构成 同时 球谐函数是Laplace算子角动量在三个维度上的特征函数 文章目录 球谐函数是什么 球坐标下的Laplace方程 Laplace方
  • [Python+sklearn] 计算混淆矩阵 confusion_matrix()函数

    python sklearn 计算混淆矩阵 confusion matrix 函数 参考sklearn官方文档 sklearn metrics confusion matrix 功能 计算混淆矩阵 以评估分类的准确性 关于混淆矩阵的含义 见
  • TeamViewer——一款强大的远程控制工具

    前言 小编在对一款软件做维护的过程中由于一些地理的原因 所以只能依靠远程控制来进行维护工作 但是远程控制会由于一些原因导致中断 此时只能依靠对方再次同意或者发起远程才能继续工作 那么我们可不可以不需要对方的同意直接进行远程操作呢 答案是肯定
  • [C++]桥接模式

    桥接模式即将抽象部分与它的实现部分分离开来 使他们都可以独立变化 桥接模式将继承关系转化成关联关系 它降低了类与类之间的耦合度 减少了系统中类的数量 也减少了代码量 github源码路径 https github com dangwei 9
  • Git-入门小结

    系列文章 Git 入门小结 Git 分支 Git 常用命令 Git 注册远程仓库 Git安装及上手 一 安装 官网下载安装 我是选择一路默认下去 具体步骤参考此篇博客 https www cnblogs com xueweisuoyong
  • c++类成员函数作为回调函数

    一直以来都被回调函数的定义给整蒙了 最近又仔细学了会 感觉回调函数 我认为就是将一个函数指针A作为参数传入另外一个函数B 然后在函数B中调用函数A 普通回调 具体先看一个简单的例子 include
  • 【RNN从入门到实战】GRU入门到实战——使用GRU预测股票。

    摘要 GRU是LSTM网络的一种效果很好的变体 它较LSTM网络的结构更加简单 而且效果也很好 因此也是当前非常流形的一种网络 GRU既然是LSTM的变体 因此也是可以解决RNN网络中的长依赖问题 在LSTM中引入了三个门函数 输入门 遗忘