LSTM详解

2023-05-16

LSTM详解

文章目录

  • LSTM详解
    • 改进
      • 记忆单元
      • 门控机制
    • LSTM结构
    • LSTM的计算过程
      • 遗忘门
      • 输入门
      • 更新记忆单元
    • 输出门
    • LSTM单元的pytorch实现
    • Pytorch中的LSTM
      • 参数
      • 输入
      • 输出
    • 参考与摘录

LSTM是RNN的一种变种,可以有效地解决RNN的梯度爆炸或者消失问题。关于RNN可以参考作者的另一篇文章https://blog.csdn.net/qq_40922271/article/details/120965322

LSTM的改进在于增加了新的记忆单元与门控机制

改进

记忆单元

LSTM进入了一个新的记忆单元 c t c_t ct,用于进行线性的循环信息传递,同时输出信息给隐藏层的外部状态 h t h_t ht。在每个时刻 t t t c t c_t ct记录了到当前时刻为止的历史信息。

门控机制

LSTM引入门控机制来控制信息传递的路径,类似于数字电路中的门,0即关闭,1即开启。

LSTM中的三个门为遗忘门 f t f_t ft输入门 i t i_t it输出门 o t o_t ot

  • f t f_t ft控制上一个时刻的记忆单元 c t − 1 c_{t-1} ct1需要遗忘多少信息
  • i t i_t it控制当前时刻的候选状态 c ~ t \tilde{c}_t c~t有多少信息需要存储
  • o t o_t ot控制当前时刻的记忆单元 c t c_t ct有多少信息需要输出给外部状态 h t h_t ht

下面我们就看看改进的新内容在LSTM的结构中是如何体现的。

LSTM结构

如图一所示为LSTM的结构,LSTM网络由一个个的LSTM单元连接而成。

image-20211025155036314

图一

图二描述了图一中各种元素的图标,从左到右分别为,神经网络 σ 表 示 s i g m o i d \sigma表示sigmoid σsigmoid)、向量元素操作 × \times ×表示向量元素乘, + + +表示向量加),向量传输的方向向量连接向量复制

image-20211025155100987

图二

LSTM 的关键就是记忆单元,水平线在图上方贯穿运行。

记忆单元类似于传送带。直接在整个链上运行,只有一些少量的线性交互。信息在上面流传保持不变会很容易。

image-20211025161323857

LSTM的计算过程

遗忘门

image-20211029104524760

在这一步中,遗忘门读取 h t − 1 h_{t-1} ht1 x t x_t xt,经由sigmoid,输入一个在0到1之间数值给每个在记忆单元 c t − 1 c_{t-1} ct1中的数字,1表示完全保留,0表示完全舍弃。

输入门

image-20211025162824862

输入门将确定什么样的信息内存放在记忆单元中,这里包含两个部分。

  1. sigmoid层同样输出[0,1]的数值,决定候选状态 c ~ t \tilde{c}_t c~t有多少信息需要存储
  2. tanh层会创建候选状态 c ~ t \tilde{c}_t c~t

更新记忆单元

随后更新旧的细胞状态,将 c t − 1 c_{t-1} ct1更新为 c t c_t ct

image-20211029104614220

首先将旧状态 c t − 1 c_{t-1} ct1 f t f_t ft相乘,遗忘掉由 f t f_t ft所确定的需要遗忘的信息,然后加上 i t ∗ c ~ t i_t*\tilde{c}_t itc~t,由此得到了新的记忆单元 c t c_t ct

输出门

结合输出门 o t o_t ot将内部状态的信息传递给外部状态 h t h_t ht。同样传递给外部状态的信息也是个过滤后的信息,首先sigmoid层确定记忆单元的那些信息被传递出去,然后,把细胞状态通过 tanh层 进行处理(得到[-1,1]的值)并将它和输出门的输出相乘,最终外部状态仅仅会得到输出门确定输出的那部分。

image-20211029104651629

通过LSTM循环单元,整个网络可以建立较长距离的时序依赖关系,以上公式可以简洁地描述为

[ c ~ t o t i t f t ] = [ t a n h σ σ σ ] ( W [ x t h t − 1 ] + b ) \begin{bmatrix} \tilde{c}_t \\ o_t \\ i_t \\ f_t \end{bmatrix} = \begin{bmatrix} tanh \\ \sigma \\ \sigma \\ \sigma \end{bmatrix} \begin{pmatrix} W \begin{bmatrix} x_t \\ h_{t-1} \end{bmatrix} +b \end{pmatrix} c~totitft=tanhσσσ(W[xtht1]+b)

c t = f t ⊙ c t − 1 + i t ⊙ c ~ t c_t=f_t \odot c_{t-1}+i_t \odot \tilde{c}_t ct=ftct1+itc~t

h t = o t ⊙ t a n h ( c t ) h_t=o_t \odot tanh(c_t) ht=ottanh(ct)

LSTM单元的pytorch实现

下面通过手写LSTM单元加深对LSTM网络的理解

class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size, cell_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size # 隐含状态h的大小,也即LSTM单元隐含层神经元数量
        self.cell_size = cell_size # 记忆单元c的大小
        # 门
        self.gate = nn.Linear(input_size+hidden_size, cell_size)
        self.output = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden, cell):
        # 连接输入x与h 
        combined = torch.cat((input, hidden), 1)
        # 遗忘门
        f_gate = self.sigmoid(self.gate(combined))
        # 输入门
        i_gate = self.sigmoid(self.gate(combined))
        z_state = self.tanh(self.gate(combined))
        # 输出门
        o_gate = self.sigmoid(self.gate(combined))
        # 更新记忆单元
        cell = torch.add(torch.mul(cell, f_gate), torch.mul(z_state, i_gate))
        # 更新隐藏状态h
        hidden = torch.mul(self.tanh(cell), o_gate)
        output = self.output(hidden)
        output = self.softmax(output)
        return output, hidden, cell
    
    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

    def initCell(self):
        return torch.zeros(1, self.cell_size)

Pytorch中的LSTM

CLASS torch.nn.LSTM(*args, **kwargs)

参数

  • input_size – 输入特征维数
  • hidden_size – 隐含状态 h h h的维数
  • num_layers – RNN层的个数:(在竖直方向堆叠的多个相同个数单元的层数),默认为1
  • bias – 隐层状态是否带bias,默认为true
  • batch_first – 是否输入输出的第一维为batchsize
  • dropout – 是否在除最后一个RNN层外的RNN层后面加dropout层
  • bidirectional –是否是双向RNN,默认为false
  • proj_size – If > 0, will use LSTM with projections of corresponding size. Default: 0

其中比较重要的参数就是hidden_sizenum_layers,hidden_size所代表的就是LSTM单元中神经元的个数。从知乎截来的一张图,通过下面这张图我们可以看出num_layers所代表的含义,就是depth的堆叠,也就是有几层的隐含层。可以看到output是最后一层layer的hidden输出的组合

image-20211025220509745

输入

 input, (h_0, c_0)
  • input: (seq_len, batch, input_size) 时间步数或序列长度,batch数,输入特征维度。如果设置了batch_first,则batch为第一维
  • h_0: shape(num_layers * num_directions, batch, hidden_size) containing the initial hidden state for each element in the batch. Defaults to zeros if (h_0, c_0) is not provided.
  • c_0: **shape(num_layers * num_directions, batch, hidden_size)**containing the initial cell state for each element in the batch. Defaults to zeros if (h_0, c_0) is not provided.

输出

output, (h_n, c_n)
  • output: (seq_len, batch, hidden_size * num_directions) 包含每一个时刻的输出特征,如果设置了batch_first,则batch为第一维
  • h_n: shape(num_layers * num_directions, batch, hidden_size) containing the final hidden state for each element in the batch.
  • c_n: shape (num_layers * num_directions, batch, hidden_size) containing the final cell state for each element in the batch.

h与c维度中的num_direction,如果是单向循环网络,则num_directions=1,双向则num_directions=2

参考与摘录

https://blog.csdn.net/qq_40728805/article/details/103959254

https://zhuanlan.zhihu.com/p/79064602

https://www.jianshu.com/p/9dc9f41f0b29

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

LSTM详解 的相关文章

随机推荐

  • 自监督学习(self-supervised learning)(20201124)

    看论文总是会看出来一堆堆奇奇怪怪的名词 从远程监督 有监督 半监督 无监督开始 xff0c 最近又看到了一个自监督 首先先对上面的概念进行简述 xff1a 半监督 xff08 semi supervised learning xff09 x
  • mynteye_sdk SDK ubuntu20 编译问题

    1 pcl问题 xff1a usr include pcl 1 10 pcl point types h 508 1 error plus is not a member of pcl traits 508 POINT CLOUD REGI
  • 异常抛出**异常捕获**with用法

    1 对于异常处理 xff0c javascript支持异常处理 xff0c 支持手动抛出异常 需要抛出的时候 xff0c 总是通过throw语句抛出Error对象 语法如下 xff1a throw new Error errorString
  • Java IO流 使用流技术将一张图片从一个目录复制到另一个目录

    题目 xff1a 使用流技术将一张图片从F images目录下 xff0c 复制到D images目录下 复制图片需要使用字节流 xff0c 使用字符流复制会将图片字节码格式进行编码 xff0c 可能会导致图片数据丢失 span class
  • 卡尔曼滤波KF

    KF 根据贝叶斯估计的原理 xff0c 卡尔曼滤波是利用已知系统模型的确定性特性和统计特性等先验知识与观测量获得最有估计 xff0c 在有初始值的情况下 xff0c 从先验值和最新观测数据中得到的新值的加权平均来更新状态估计 1 卡尔曼滤波
  • 解决VS2019提示未能加载项目文件。缺少根元素的错误

    解决VS2019提示未能加载项目文件 缺少根元素的错误 上次win10自动更新关掉了打开的vs xff0c 导致开机后再打开解决方案 xff0c 某个项目一直无法加载 xff0c 提示未能加载项目文件 缺少根元素的错误 迁移报告上显示这个项
  • 实验二 串口通信及中断实验

    一 xff0e 实验目的 xff08 1 xff09 熟悉 MCU 的异步串行通信 Uart 的工作原理 xff08 2 xff09 掌握 Uart 的通信编程方法 xff08 3 xff09 掌握中断的编程方法 xff08 4 xff09
  • 多线程编程入门——C++ 「semaphore.h」

    回顾OS xff0c 发现自己基本没有实际操作过多线程编程 xff0c 所以想从今天开始学习 从Leedcode的 1114 按序打印开始 xff1a 信号量类型 sem t 原型 xff1a extern int sem init P s
  • 使用网络调试助手连接EMQ服务器

    一 使用MQTT协议与服务器建立连接 1 在图纸中的位置输入EMQ服务器地址与EMQ服务器的端口 2 接下来我们来看一下MQTT协议中的CONNECT报文 1 固定报头 byte1为0x10 xff0c 表示向服务器端发送的为CONNECT
  • npm下载以来版本问题 npm ERR! code ERESOLVE

    这里就是提示npm的版本太高 xff0c 这个时候 xff0c 需要我们问一下原来开发人员的npm的版本号 xff0c 在进行npm insdtall g npm 64 版本号 安装vue element admin项目问题补充 xff1a
  • 在机器人中执行完的仿真怎么移植到真实机器人上?

    最近在鱼香ROS上看了一篇文章 xff0c 在自己的号上记录一下用以保存 一 搞清楚数据流图 1 1建图 以常见的功能包来说 xff0c 一般都可以在运行的时候生成对应的话题输入与输出 xff0c cmd vel用于控制gazebo中的小车
  • 算法——最长公共子序列(动态规划)

    给定两个字符串 text1 和 text2 xff0c 返回这两个字符串的最长 公共子序列 的长度 如果不存在 公共子序列 xff0c 返回 0 一个字符串的 子序列 是指这样一个新的字符串 xff1a 它是由原字符串在不改变字符的相对顺序
  • VINF_FUSION编译出现大量的error: ‘CV_CALIB_CB_ADAPTIVE_THRESH’ was not declared in this scope错误

    错误 span class token operator span home span class token operator span lee span class token operator span catkin ws span
  • 在线古诗自动生成器的设计与实现

    在线古诗自动生成器的设计与实现 前言一 算法模型介绍LSTM简介模型框架实验环境实验与分析实验数据集数据集预处理训练过程模型训练结果模型的评估 二 在线古诗生成器的设计与实现系统结构远程服务器的项目部署系统测试 三 成品展示 前言 古诗 x
  • S7503E V7 snmpv3典型组网配置案例(与IMC联动)

    转载来源 xff1a S7503E V7 snmpv3典型组网配置案例 xff08 与IMC联动 xff09 https mp weixin qq com s idTHFiRDRRZX9nkd pOSiA 组网及说明 本案例为S7503E
  • OBS Studio录屏软件安装和使用教程

    OBS Studio 全称Open Broadcaster Software Studio 是一个免费的开源的无水印的且不限制时长的视频录制软件 1 首先百度搜索 34 联想软件商店 34 xff0c 单击带有 官方 的即可或直接打开htt
  • MARKDOWN-插入图片

    MARKDOWN 插入图片 1 markdown是什么 Markdown 是一种轻量级标记语言 xff0c 创始人为约翰 格鲁伯 xff08 John Gruber xff09 它允许人们使用易读易写的纯文本格式编写文档 xff0c 然后转
  • xshell6评估期已过,解决方法

    xshell6评估期已过一般因为下载的版本是evaluation 30天评估 版本 xff0c 是有期限的 xff0c 解决如下 xff1a 1 前往下载地址 xff0c 点击免费授权页面 xff1a 2 填写必填信息 xff0c 邮箱一定
  • JetsonNano配置RealSense D435i运行环境

    JetsonNano配置RealSense D435i运行环境 文章目录 JetsonNano配置RealSense D435i运行环境0 前言1 系统环境2 安装ubuntu源自带的SDK问题及解决问题解决方法 xff1a 3 源码编译S
  • LSTM详解

    LSTM详解 文章目录 LSTM详解改进记忆单元门控机制 LSTM结构LSTM的计算过程遗忘门输入门更新记忆单元 输出门LSTM单元的pytorch实现Pytorch中的LSTM参数输入输出 参考与摘录 LSTM是RNN的一种变种 xff0