循环神经网络RNN以及几种经典模型

2023-10-26

RNN简介

现实世界中,很多元素都是相互连接的,比如室外的温度是随着气候的变化而周期性的变化的、我们的语言也需要通过上下文的关系来确认所表达的含义。但是机器要做到这一步就相当得难了。因此,就有了现在的循环神经网络,他的本质是:拥有记忆的能力,并且会根据这些记忆的内容来进行推断。因此,他的输出就依赖于当前的输入和记忆。

网络结构及原理

循环神经网络的基本结构特别简单,就是将网络的输出保存在一个记忆单元中,这个记忆单元和下一次的输入一起进入神经网络中。

一个最简单的循环神经网络在输入时的结构示意图:

1

RNN 可以被看做是同一神经网络的多次赋值,每个神经网络模块会把消息传递给下一个,我们将这个图的结构展开:

2

根据循环神经网络的结构也可以看出它在处理序列类型的数据上具有天然的优势。因为网络本身就是 一个序列结构,这也是所有循环神经网络最本质的结构。

我们可以用下面的公式来表示循环神经网络的计算方法:

3

总结图:

4

Pytorch中

pytorch 中使用 nn.RNN 类来搭建基于序列的循环神经网络,它的构造函数有以下几个参数:

  • input_size:输入数据X的特征值的数目。
  • hidden_size:隐藏层的神经元数量,也就是隐藏层的特征数量。
  • num_layers:循环神经网络的层数,默认值是 1。
  • bias:默认为 True,如果为 false 则表示神经元不使用 bias 偏移参数。
  • batch_first:如果设置为 True,则输入数据的维度中第一个维度就是 batch 值,默认为 False。默认情况下第一个维度是序列的长度, 第二个维度才是batch,第三个维度是特征数目。
  • dropout:如果不为空,则表示最后跟一个 dropout 层抛弃部分数据,抛弃数据的比例由该参数指定

RNN 中最主要的参数是 input_sizehidden_size,这两个参数务必要搞清楚。其余的参数通常不用设置,采用默认值就可以了。

rnn = torch.nn.RNN(20,50,2)
input = torch.randn(100 , 32 , 20)
h_0 =torch.randn(2 , 32 , 50)
output,hn=rnn(input ,h_0) 
print(output.size(),hn.size())
'''
torch.Size([100, 32, 50]) torch.Size([2, 32, 50])
'''

参考
一文搞懂RNN(循环神经网络)基础篇
RNN
详解循环神经网络(Recurrent Neural Network)

LSTM

Long Short Term Memory Networks 长短期记忆网络

它解决了短期依赖的问题,并且它通过刻意的设计来避免长期依赖问题

思路

原始 RNN 的隐藏层只有一个状态,即h,它对于短期的输入非常敏感。

再增加一个状态,即c,让它来保存长期的状态,称为单元状态(cell state)

把上图按照时间维度展开:

在 t 时刻,LSTM 的输入有三个:

  • 当前时刻网络的输入值 x_t
  • 上一时刻 LSTM 的输出值 h_t-1
  • 上一时刻的单元状态 c_t-1

LSTM 的输出有两个:

  • 当前时刻 LSTM 输出值 h_t
  • 当前时刻的单元状态 c_t

控制长期状态c

LSTM中使用三个控制开关来控制

结构

标准的循环神经网络内部只有一个简单的层结构,而 LSTM 内部有 4 个层结构:

  1. 忘记层:决定状态中丢弃什么信息

  2. tanh层:用来产生更新值的候选项,说明状态在某些维度上需要加强,在某些维度上需要减弱

  3. sigmoid层(输入门层): 它的输出值要乘到tanh层的输出上,起到一个缩放的作用,极端情况下sigmoid输出0说明相应维度上的状态不需要更新

  4. 最后一层决定输出什么,输出值跟状态有关。候选项中的哪些部分最终会被输出由一个sigmoid层来决定。

Pytorch中

pytorch 中使用 nn.LSTM 类来搭建基于序列的循环神经网络,他的参数基本与RNN类似

lstm = torch.nn.LSTM(10, 20,2)
input = torch.randn(5, 3, 10)
h0 =torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, hn = lstm(input, (h0, c0))
print(output.size(),hn[0].size(),hn[1].size())
'''
torch.Size([5, 3, 20]) torch.Size([2, 3, 20]) torch.Size([2, 3, 20])
'''

详解LSTM

GRU

Gated Recurrent Units

GRU 和 LSTM 最大的不同在于 GRU 将遗忘门和输入门合成了一个"更新门",同时网络不再额外给出记忆状态,而是将输出结果作为记忆状态不断向后循环传递,网络的输人和输出都变得特别简单。

Pytorch中

rnn = torch.nn.GRU(10, 20, 2)
input = torch.randn(5, 3, 10)
h_0= torch.randn(2, 3, 20)
output, hn = rnn(input, h0)
print(output.size(),h0.size())
'''
torch.Size([5, 3, 20]) torch.Size([2, 3, 20])
'''

循环网络的向后传播 BPTT

在向前传播的情况下,RNN的输入随着每一个时间步前进。在反向传播的情况下,我们“回到过去”改变权重,因此我们叫它通过时间的反向传播(BPTT)

我们通常把整个序列(单词)看作一个训练样本,所以总的误差是每个时间步(字符)中误差的和。权重在每一个时间步长是相同的(所以可以计算总误差后一起更新)。

  1. 使用预输出和实际输出计算交叉熵误差
  2. 网络按照时间步完全展开
  3. 对于展开的网络,对于每一个时间步计算权重的梯度
  4. 因为对于所有时间步来说,权重都一样,所以对于所有的时间步,可以一起得到梯度(而不是像神经网络一样对不同的隐藏层得到不同的梯度)
  5. 随后对循环神经元的权重进行升级

RNN展开的网络看起来像一个普通的神经网络。反向传播也类似于普通的神经网络,只不过我们一次得到所有时间步的梯度。如果有100个时间步,那么网络展开后将变得非常巨大,所以为了解决这个问题才会出现LSTM和GRU这样的结构。

RNN用于NLP时的储备知识

词嵌入 word embedding

为了让计算机能够能更好地理解我们的语言,建立更好的语言模型,我们需要将词汇进行表征。

在图像分类问题会使用 one-hot 编码。比如LeNet中一共有10个数字0-9,如果这个数字是2的话,它的编码就是 (0,0,1,0, 0,0 ,0,0,0,0),对于分类问题这样表示十分的清楚。

但是在自然语言处理中,因为单词的数目过多比如有 10000 个不同的词,那么使用 one-hot 这样的方式来定义,效率就特别低,并且占用内存,也不能体现单词的词性, one-hot 没办法体现这个特点,所以 必须使用另外一种方式定义每一个单词。

不同的特征来对各个词汇进行表征,相对于不同的特征,不同的单词均有不同的值,这就是词嵌入

词嵌入不仅对不同单词实现了特征化的表示,还能通过计算词与词之间的相似度

实际上是在多维空间中,寻找词向量之间各个维度的距离相似度,我们就可以实现类比推理,比如说夏天和热,冬天和冷,都是有关联关系的。

在 PyTorch 中我们用 nn.Embedding 层来做嵌入词袋模型,Embedding层第一个输入表示我们有多少个词,第二个输入表示每一个词使用多少维度的向量表示。

# an Embedding module containing 10 tensors of size 3
embedding = torch.nn.Embedding(10, 3)
# a batch of 2 samples of 4 indices each
input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
output=embedding(input)
print(output.size())
'''
torch.Size([2, 4, 3])
'''

Beam Search

在生成第一个词的分布后,可以使用贪心搜索会根据我们的条件语言模型挑选出最有可能输出的第一个词语,但是对于贪心搜索算法来说,我们的单词库中有成百到千万的词汇,去计算每一种单词的组合的可能性是不可行的。所以我们使用近似的搜索办法,使得条件概率最大化或者近似最大化的句子,而不是通过单词去实现。

Beam Search(集束搜索)是一种启发式图搜索算法,通常用在图的解空间比较大的情况下,为了减少搜索所占用的空间和时间,在每一步深度扩展的时候,剪掉一些质量比较差的结点,保留下一些质量较高的结点。虽然Beam Search算法是不完全的,但是用于了解空间较大的系统中,可以减少空间占用和时间。

Beam search可以看做是做了约束优化的广度优先搜索,首先使用广度优先策略建立搜索树,树的每层,按照启发代价对节点进行排序,然后仅留下预先确定的个数(Beam width-集束宽度)的节点,仅这些节点在下一层次继续扩展,其他节点被剪切掉。

  1. 将初始节点插入到list中
  2. 将给节点出堆,如果该节点是目标节点,则算法结束;
  3. 否则扩展该节点,取集束宽度的节点入堆。然后到第二步继续循环。
  4. 算法结束的条件是找到最优解或者堆为空。

在使用上,集束宽度可以是预先约定的,也可以是变化的,具体可以根据实际场景调整设定。

Beam Search

注意力模型

对于使用编码和解码的RNN模型,我们能够实现较为准确度机器翻译结果。对于短句子来说,其性能是十分良好的,但是如果是很长的句子,翻译的结果就会变差。
我们人类进行人工翻译的时候,都是一部分一部分地进行翻译,引入的注意力机制,和人类的翻译过程非常相似,其也是一部分一部分地进行长句子的翻译。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

循环神经网络RNN以及几种经典模型 的相关文章

  • Python执行windows cmd函数

    我知道你可以使用 subprocess 通过 Python 脚本运行 Linux 终端命令 subprocess call ls l for linux 但我找不到在 Windows 上做同样事情的方法 subprocess call di
  • 龙卷风网络和线程

    我是 Tornado 和 Python 线程的新手 我想要实现的目标如下 我有一个龙卷风网络服务器 它接受用户的请求 我想在本地存储一些数据 并定期将其作为批量插入写入数据库 import tornado ioloop import tor
  • 在 Django 中处理 subprocess.call()

    我正在开发的应用程序的简单想法是用户给出 Linux 命令 Linux 命令的结果将显示在网络浏览器中 这是我的观点 py from django shortcuts import render to response from djang
  • pip 中的新彩色终端进度条

    我发现新版本的pip Python的包安装程序 有一个彩色进度条来显示下载进度 我怎样才能做到这一点 Like this pip 本身正在使用rich https pypi org project rich 包裹 特别是 他们的进度条文档
  • Boto3 - 打印 AWS 实例平均 CPU 利用率

    我正在尝试仅打印 AWS 实例的平均 CPU 利用率 此代码将打印出 响应 但最后的 for 循环不会打印平均利用率 有人可以帮忙吗 先感谢您 import boto3 import sys from datetime import dat
  • 在 Python 中打开文本文件时出现问题

    这看起来应该很简单 f open C Users john Desktop text txt r 但我收到此错误 Traceback most recent call last File
  • 用户在对话框中输入

    python 中是否有任何库可用于图形用户输入 我知道关于tk但我相信需要一些代码才能做到这一点 我正在寻找最短的解决方案 a input Enter your string here 取而代之的是 我想要一个对话框 以便用户可以在那里输入
  • 实现一个java UDF并从pyspark调用它

    我需要创建一个在 pyspark python 中使用的 UDF 它使用 java 对象进行内部计算 如果它是一个简单的 python 我会做类似的事情 def f x return 7 fudf pyspark sql functions
  • 使用 Pyodbc + UnixODBC + FreeTDS 设置连接设置

    我使用 Pyodbc UnixODBC 和 FreeTDS 进行了设置 但在其中的某个地方设置了一些选项 但我不知道在哪里 根据 SQL Server Management Studio 我的程序在打开连接时发送一些设置 set quote
  • 如何禁用Excel自动识别数字和文本

    我使用 Python 生成了 CSV 文件 但是当我在Excel中打开它时 如果可以转换 Excel会自动将字符串识别为数字 e g 33E105变成33 10 105 这实际上是一个ID 而不是一个数字 如何在打开 CSV 文件时在 Ex
  • 如何加载 caffe 模型并转换为 numpy 数组?

    我有一个 caffemodel 文件 其中包含 ethereon 的 caffe tensorflow 转换实用程序不支持的层 我想生成我的咖啡模型的 numpy 表示 我的问题是 如何将 caffemodel 文件 我还有 prototx
  • pip 安装与本地包具有相同命名空间的包

    我使用的是 Python 3 6 5 通过 miniconda 安装 我的问题是由于我正在安装一个与本地包具有相同命名空间的包 pip 安装此包后 我无法再从本地包导入 我收到一个ModuleNotFoundError错误 如果可能的话 命
  • 如何从另一个 Python 文件将 Uvicorn FastAPI 服务器作为模块运行?

    我想使用 Uvicorn 从不同的 Python 文件运行 FastAPI 服务器 uvicorn模块 main py import uvicorn import webbrowser from fastapi import FastAPI
  • 在matplotlib中绘制曲线连接点

    所以我试图绘制曲线来连接点 这是我正在使用的代码 def hanging line point1 point2 a point2 1 point1 1 np cosh point2 0 np cosh point1 0 b point1 1
  • 导入后属性未添加到模块中

    我做了以下实验室 vagrant ubuntu xenial test tree pack1 init py mod1 py pack2 init py mod2 py mod3 py test py 2 directories 6 fil
  • Python:计算非整数的阶乘

    我想知道是否有一种快速的 Pythonic 的方法来计算非整数的阶乘 例如 3 4 当然 内置的factorial 函数在Math模块可用 但它仅适用于积分 我不关心这里的负数 你想用math gamma x http docs pytho
  • 如何将目录导入为 python 模块

    如果有目录 home project aaa 我知道它是一个Python包 那么 我如何通过知道它的路径来导入这个模块 这意味着 我希望代码能够正常工作 aaa load module home project aaa 我知道的唯一方法是
  • 使用scrapy到json文件只得到一行输出

    好吧 我对一般编程很陌生 并且具体使用 Scrapy 来实现此目的 我编写了一个爬虫来从 pinterest com 上的 pin 获取数据 问题是我以前从我正在抓取的页面上的所有引脚获取数据 但现在我只获取第一个引脚的数据 我认为问题出在
  • 使用 NaN 获取 pandas 系列模式的最快方法

    我需要找到 pandas groupby 对象或单个系列的模式 最常见元素 为此我有以下函数 def get most common srs from collections import Counter import numpy as n
  • 字典条目被覆盖? [复制]

    这个问题在这里已经有答案了 我发现一些输入没有存储在 Python 3 的字典中 运行这段代码 N int input How many lines of subsequent input graph for n in range N st

随机推荐

  • unity在同屏幕显示多Camera并在脚本中修改Viewport Rece

    参考 https www it610 com article 1305219586412548096 htm 参考 https www zhihu com question 41879088 sort created 修改Camera的Vi
  • 开放平台认证方案

    背景 本次的直接起因是第三方那边接入系统后端引起的 第三方方觉得认证要过期比较麻烦 而且要用账号密码去调登录接口去刷token 设计不合理 客观来说 凭本人使用过其它开放平台来说确实有些不一样 常见的一些开放平台 有带web的 一般web能
  • 感知机及算法实现

    1 感知机二类分类的线性分类模型 输入为实例的特征向量 输出为实例的类别 取 1和 1二值 感知机对应于输入空间中将实例划分为正负两类的分离超平面 属于判别模型 感知机学习旨在求出将训练数据进行线性划分的分离超平面 为此导入基于误分类的损失
  • error: use of deleted function

    本文案例仅供参考 出错的代码如下 TEST Test test1 TestImpl impl TestImpl para1 para2 ASSERT EQ jkj impl func 22 33 44 实际应该这样 TEST Test te
  • PyCharm下载包出错

    PyCharm安装成功之后添加所需的包 File gt Settings gt Project 此处是你的Python工作环境 gt Project Interpreter 红色剪头所指 添加需要的包 点开时候出现错误信息 Error lo
  • phpstorm运行php出现502 Bad Gateway

    个人博客开通啦 功能正在逐步完善中 大家可以访问http www codeliu com 记一次心碎的经历 我用的phpstorm10 0 1 XAMPP 今天写完一个php文件后 运行出现502 Bad Gateway的错误 明明上一刻还
  • c语言中的常见数据类型

    一 常见的数据类型包括基本类型 枚举类型 空类型和派生类型 基本类型又包括整型类型 浮点类型 整型类型 基本类型 int 短整型 short int 长整型 long int 双长整型 long long int 字符型 char 布尔型
  • 判断一个字符是否是十六进制

    判断一个字符是否是十六进制 十六进制 hexadecimal 是计算机中数据的一种表示方法 意思是逢十六进一 十六进制数以16为基数 采用的数码是0 1 2 3 4 5 6 7 8 9 A B C D E F 其中A F分别表示十进制数字1
  • JAVA中的异常处理

    一 什么是异常 异常是指在程序执行过程中出现的错误或异常情况 它可能是由于错误的输入 无效的操作 资源不可用等原因引起的 当程序遇到异常时 它会中断当前的执行路径 并转到能够处理该异常的代码块 在 Java 中 异常是以对象的形式表示的 它
  • PID串行多闭环控制与并行多闭环控制的优缺点分析和应用比较

    导言 在自动控制领域 PID控制器是一种经典的控制策略 被广泛应用于各种工业和非工业过程 随着控制系统的复杂性增加 PID串行多闭环控制和PID并行多闭环控制成为解决复杂控制问题的重要方法 本文将从优点和缺点的角度对这两种控制策略进行对比
  • Android基础之Fragment

    目录 前言 一 Fragment简介 二 Fragment的基础使用 1 创建Fragment 2 在Activity中加入Fragment 1 在Activity的layout xml布局文件中静态添加 2 在Activity的 java
  • 数学建模--粒子群算法(PSO)的Python实现

    目录 1 开篇提示 2 算法流程简介 3 算法核心代码 4 算法效果展示 1 开篇提示 开篇提示 这篇文章是一篇学习文章 思路和参考来自 https blog csdn net weixin 42051846 article details
  • 宝峰对讲机16频率表_宝峰888S对讲机的16个信道频率是多少?

    1 宝峰888S对讲机 16个工作频率范围为 400 470MHZ 16个信道 频率范围内 任意频道任意频率 内 2 一般对讲机没容有固定频点 出厂都是空频机器 每个信道的频率都可以写成机器频率范围内的任意频点也可以空白什么都不写 3 根据
  • 矩阵求逆四种方法

    注 用A B表示某矩阵 E表示单位矩阵 用A 表示A逆 用 A 表示A的行列式 A E 表示拼接矩阵 一 公式法 先求A行列式结果 再求A伴随矩阵 最后再求A逆矩阵 A 0 则 A A A 注 图片中detA就是 A 二 初等变换法 A E
  • 【沧海拾昧】Proteus8仿真stm32:ADC转换程序

    C0102 沧海茫茫千钟粟 且拾吾昧一微尘 沧海拾昧集 CuPhoenix 阅前敬告 沧海拾昧集仅做个人学习笔记之用 所述内容不专业不严谨不成体系 如有问题必是本集记录有谬 切勿深究 目录 一 原理图绘制 二 多位七段数码管 三 ADC引脚
  • 一维动态规划总结

    题目列表 给一个N 输入 求某种情况的最大值或者最小值情况 279 Perfect Squares 思路 最差情况下 总体是定义一个dp N 1 或者初始化前面dp 0 或者dp 1 279 Perfect Squares 解析 Given
  • sql:command not found

    写一个脚本zl sh 用来删除数据库mydatabase中某个表mytable的某行数据 bin bash HOSTNAME 127 0 0 1 PORT 2918 USERNAME root PASSWORD root TABLENAME
  • 使用mockjs创建假数据

    npm install mockjs 创建mock文件夹 在mock文件夹下创建1 js 1 js import Mock from mockjs 引入mockjs export default Mock mock postdata1 po
  • 剑网三服务器缺少必要启动文件,win7系统玩剑网三游戏经常掉线的解决方法

    很多小伙伴都遇到过win7系统玩剑网三游戏经常掉线的困惑吧 一些朋友看过网上零散的win7系统玩剑网三游戏经常掉线的处理方法 并没有完完全全明白win7系统玩剑网三游戏经常掉线是如何解决的 今天小编准备了简单的解决办法 只需要按照1 掉线基
  • 循环神经网络RNN以及几种经典模型

    RNN简介 现实世界中 很多元素都是相互连接的 比如室外的温度是随着气候的变化而周期性的变化的 我们的语言也需要通过上下文的关系来确认所表达的含义 但是机器要做到这一步就相当得难了 因此 就有了现在的循环神经网络 他的本质是 拥有记忆的能力