LSTM原理及实现

2023-11-15

 

前面我们介绍了RNN,现在我们来介绍一种特殊的RNN结构,LSTM网络。我们将逐步介绍LSTM的结构,原理,以及利用LSTM识别手写数字的demo跟深刻的理解LSTM。

LSTM网络

long short term memory,即我们所称呼的LSTM,是为了解决长期以来问题而专门设计出来的,所有的RNN都具有一种重复神经网络模块的链式形式。在标准RNN中,这个重复的结构模块只有一个非常简单的结构,例如一个tanh层。

LSTM 同样是这样的结构,但是重复的模块拥有一个不同的结构。不同于单一神经网络层,这里是有四个,以一种非常特殊的方式进行交互。

不必担心这里的细节。我们会一步一步地剖析 LSTM 解析图。现在,我们先来熟悉一下图中使用的各种元素的图标。

在上面的图例中,每一条黑线传输着一整个向量,从一个节点的输出到其他节点的输入。粉色的圈代表 pointwise 的操作,诸如向量的和,而黄色的矩阵就是学习到的神经网络层。合在一起的线表示向量的连接,分开的线表示内容被复制,然后分发到不同的位置。

LSTM核心思想

LSTM的关键在于细胞的状态整个(绿色的图表示的是一个cell),和穿过细胞的那条水平线。

细胞状态类似于传送带。直接在整个链上运行,只有一些少量的线性交互。信息在上面流传保持不变会很容易。

若只有上面的那条水平线是没办法实现添加或者删除信息的。而是通过一种叫做 门(gates) 的结构来实现的。

门 可以实现选择性地让信息通过,主要是通过一个 sigmoid 的神经层 和一个逐点相乘的操作来实现的。

sigmoid 层输出(是一个向量)的每个元素都是一个在 0 和 1 之间的实数,表示让对应信息通过的权重(或者占比)。比如, 0 表示“不让任何信息通过”, 1 表示“让所有信息通过”。

LSTM通过三个这样的本结构来实现信息的保护和控制。这三个门分别输入门、遗忘门和输出门。

逐步理解LSTM

现在我们就开始通过三个门逐步的了解LSTM的原理

遗忘门

在我们 LSTM 中的第一步是决定我们会从细胞状态中丢弃什么信息。这个决定通过一个称为忘记门层完成。该门会读取ht−1ht−1和xtxt,输出一个在 0到 1之间的数值给每个在细胞状态 Ct−1Ct−1 中的数字。1 表示“完全保留”,0 表示“完全舍弃”。

让我们回到语言模型的例子中来基于已经看到的预测下一个词。在这个问题中,细胞状态可能包含当前主语的性别,因此正确的代词可以被选择出来。当我们看到新的主语,我们希望忘记旧的主语。

其中ht−1ht−1表示的是上一个cell的输出,xtxt表示的是当前细胞的输入。σσ表示sigmod函数。

输入门

下一步是决定让多少新的信息加入到 cell 状态 中来。实现这个需要包括两个 步骤:首先,一个叫做“input gate layer ”的 sigmoid 层决定哪些信息需要更新;一个 tanh 层生成一个向量,也就是备选的用来更新的内容,C^tC^t 。在下一步,我们把这两部分联合起来,对 cell 的状态进行一个更新。

现在是更新旧细胞状态的时间了,Ct−1Ct−1更新为CtCt。前面的步骤已经决定了将会做什么,我们现在就是实际去完成。

我们把旧状态与ftft相乘,丢弃掉我们确定需要丢弃的信息。接着加上it∗C~tit∗C~t。这就是新的候选值,根据我们决定更新每个状态的程度进行变化。

在语言模型的例子中,这就是我们实际根据前面确定的目标,丢弃旧代词的性别信息并添加新的信息的地方。

输出门

最终,我们需要确定输出什么值。这个输出将会基于我们的细胞状态,但是也是一个过滤后的版本。首先,我们运行一个 sigmoid 层来确定细胞状态的哪个部分将输出出去。接着,我们把细胞状态通过 tanh 进行处理(得到一个在 -1 到 1 之间的值)并将它和 sigmoid 门的输出相乘,最终我们仅仅会输出我们确定输出的那部分。

在语言模型的例子中,因为他就看到了一个 代词,可能需要输出与一个 动词 相关的信息。例如,可能输出是否代词是单数还是负数,这样如果是动词的话,我们也知道动词需要进行的词形变化。

LSTM变体

原文这部分介绍了 LSTM 的几个变种,还有这些变形的作用。在这里我就不再写了。有兴趣的可以直接阅读原文。

下面主要讲一下其中比较著名的变种 GRU(Gated Recurrent Unit ),这是由 Cho, et al. (2014) 提出。在 GRU 中,如下图所示,只有两个门:重置门(reset gate)和更新门(update gate)。同时在这个结构中,把细胞状态和隐藏状态进行了合并。最后模型比标准的 LSTM 结构要简单,而且这个结构后来也非常流行。

其中, rtrt表示重置门,ztzt表示更新门。重置门决定是否将之前的状态忘记。(作用相当于合并了 LSTM 中的遗忘门和传入门)当rtrt趋于0的时候,前一个时刻的状态信息ht−1ht−1会被忘掉,隐藏状态h^th^t会被重置为当前输入的信息。更新门决定是否要将隐藏状态更新为新的状态h^th^t(作用相当于 LSTM 中的输出门) 。

和 LSTM 比较一下: 
- GRU 少一个门,同时少了细胞状态CtCt。 
- 在 LSTM 中,通过遗忘门和传入门控制信息的保留和传入;GRU 则通过重置门来控制是否要保留原来隐藏状态的信息,但是不再限制当前信息的传入。 
- 在 LSTM 中,虽然得到了新的细胞状态 Ct,但是还不能直接输出,而是需要经过一个过滤的处理:ht=ot∗tanh(Ct)ht=ot∗tanh(Ct);同样,在 GRU 中, 虽然我们也得到了新的隐藏状态h^th^t, 但是还不能直接输出,而是通过更新门来控制最后的输出:ht=(1−zt)∗ht−1+zt∗h^tht=(1−zt)∗ht−1+zt∗h^t

多层LSTM

多层LSTM是将LSTM进行叠加,其优点是能够在高层更抽象的表达特征,并且减少神经元的个数,增加识别准确率并且降低训练时间。具体信息参考[3]

LSTM实现手写数字

这里我们利用的数据集是tensorflow提供的一个手写数字数据集。该数据集是一个包含n张28*28的数据集。

设置LSTM参数

# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.contrib import rnn

import numpy as np
import input_data

# configuration
#                        O * W + b -> 10 labels for each image, O[? 28], W[28 10], B[10]
#                       ^ (O: output 28 vec from 28 vec input)
#                       |
#      +-+  +-+       +--+
#      |1|->|2|-> ... |28| time_step_size = 28
#      +-+  +-+       +--+
#       ^    ^    ...  ^
#       |    |         |
# img1:[28] [28]  ... [28]
# img2:[28] [28]  ... [28]
# img3:[28] [28]  ... [28]
# ...
# img128 or img256 (batch_size or test_size 256)
#      each input size = input_vec_size=lstm_size=28

# configuration variables
input_vec_size = lstm_size = 28 # 输入向量的维度
time_step_size = 28 # 循环层长度

batch_size = 128
test_size = 256

这里设置将batch_size设置为128,time_step_size表示的是lstm神经元的个数,这里设置为28个(和图片的尺寸有关?),input_vec_size表示一次输入的像素数。

初始化权值参数

def init_weights(shape):
    return tf.Variable(tf.random_normal(shape, stddev=0.01))

def model(X, W, B, lstm_size):
    # X, input shape: (batch_size, time_step_size, input_vec_size)
    # XT shape: (time_step_size, batch_size, input_vec_size)
    #对这一步操作还不是太理解,为什么需要将第一行和第二行置换
    XT = tf.transpose(X, [1, 0, 2])  # permute time_step_size and batch_size,[28, 128, 28]
    # XR shape: (time_step_size * batch_size, input_vec_size)
    XR = tf.reshape(XT, [-1, lstm_size]) # each row has input for each lstm cell (lstm_size=input_vec_size)

    # Each array shape: (batch_size, input_vec_size)
    X_split = tf.split(XR, time_step_size, 0) # split them to time_step_size (28 arrays),shape = [(128, 28),(128, 28)...]
    # Make lstm with lstm_size (each input vector size). num_units=lstm_size; forget_bias=1.0
    lstm = rnn.BasicLSTMCell(lstm_size, forget_bias=1.0, state_is_tuple=True)

    # Get lstm cell output, time_step_size (28) arrays with lstm_size output: (batch_size, lstm_size)
    # rnn..static_rnn()的输出对应于每一个timestep,如果只关心最后一步的输出,取outputs[-1]即可
    outputs, _states = rnn.static_rnn(lstm, X_split, dtype=tf.float32)  # 时间序列上每个Cell的输出:[... shape=(128, 28)..]
    # tanh activation
    # Get the last output
    return tf.matmul(outputs[-1], W) + B, lstm.state_size # State size to initialize the state

init_weigths函数利用正态分布随机生成参数的初始值,model的四个参数分别为:X为输入的数据,W表示的是28*10的权值(标签为0-9),B表示的是偏置,维度和W一样。这里首先将一批128*(28*28)的图片放进神经网络。然后进行相关的操作(注释已经写得很明白了,这里就不再赘述),然后利用WX+B求出预测结果,同时返回lstm的尺寸

训练

py_x, state_size = model(X, W, B, lstm_size)

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)

然后通过交叉熵计算误差,反复训练得到最优值。

源代码: 
https://github.com/geroge-gao/deeplearning/tree/master/LSTM

参考资料

[1].https://www.jianshu.com/p/9dc9f41f0b29

[2].http://blog.csdn.net/Jerr__y/article/details/58598296 
[3].Stacked Long Short-Term Memory Networks 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

LSTM原理及实现 的相关文章

  • nohup 后台启动程序,并输出到指定日志

    1 启动程序并输入到指定日志 nohup python manage py runserver 0 0 0 0 9090 gt data zyj xadstat xadstat log 2 gt 1 或者 nohup python mana
  • 企培版edusoho对接第三方云视频点播 最新版本代码披露 支持m3u8视频加密

    edusoho企培系列版本更新日志 新增功能和优化历史 倍数播放功能 视频分类 支持m3u8视频加密 plugins AliVideoPlugin DependencyInjection Configuration php
  • 零基础入门网络安全,收藏这篇不迷茫【2023 最新】

    零基础入门网络安全 收藏这篇不迷茫 2023 最新 前言 最近收到不少关注朋友的私信和留言 大多数都是零基础小友入门网络安全 需要相关资源学习 其实看过的铁粉都知道 之前的文里是有过推荐过的 新来的小友可能不太清楚 这里就系统地叙述一遍 0
  • Qt connect的实现原理

    概述 connect实质上是将对象A的信号和对象B的槽函数进行连接 然后返回一个句柄Connection 正文 下面通过源码来解析一下 注意看中文注释 connection表示信号槽连接句柄 QMetaObject Connection Q
  • 15. 从0开始学ARM-位置无关码

    目录 十九 位置无关码 一 为什么需要位置无关码 1 exynos 4412启动流程 二 怎么实现位置无关码 1 什么是 编译地址 什么是 运行地址 2 举例 3 代码 四 总结 1 位置无关码 2 位置相关码 3 位置无关码的应用 4 结
  • 动态实体类方案1.0(虚拟实体类生成器)[万能实体]

    该工具能实现任何实体类的动态生成传入参数名称自动生成get set方法 供反射调用 该方法生成的实体类是在程序运行过程动态生成加载出来的 实际代码文件并不存在 所以我暂定他为虚拟实体类生成工具 本方法我自己暂时用在mybatis中当统一的传
  • DETR源码学习(三)之损失函数与后处理

    在DETR模型中 在完成DETR模型的构建后 我们送入数据在完成前向传播后就需要使用预测值与真实值进行计算损失来进行反向传播进而更新梯度 在DETR模型中 其标签匹配采用的是匈牙利匹配算法 主要涉及models matcher py mod
  • 调用百度API实现人脸识别

    人脸识别 听着很高级 但实际上它确实很高级 不过对于我们开发人员来说 我们大部分人都是拿来主义 这次展示的是调用百度人脸识别API进行人脸信息分析 笔者试了下 发现还是挺准确的 而且代码量很少才8行 用的python 如果用java铁定不止
  • js拼接字符串与变量

    使用eval 方法可将拼接后的字符串与变量转变为变量 var field test 我是小白鼠一号 var field test 我是小白鼠二号 然后在JS里尝试将前面的语言简写当成变量 拼接后面的字符串 var lang field va
  • 含泪整理最优质Fbx 3d模型素材,你想要的这里都有

    今天小编针对Fbx 3d模型素材为大家整理了很多内容哦 肯定有需要的小伙伴吧 实用 免费 优质的素材谁又不心动呢 赶紧码住 接下来就给大家介绍一下我珍藏已久的网站 我的工作灵感都是来源它哦 里面的Fbx 3d模型资源数量多 种类丰富 并且每
  • Ubuntu16.04搭建Fabric1.4环境

    一 换源 为了提高下载速度 将ubuntu的源改成国内的源 推荐阿里云源和清华源 apt源保存在 etc apt sources list 代表根目录 etc 这个文件夹几乎放置了系统的所有配置文件 1 备份 sudo cp etc apt
  • shell基础+强化

    shell脚本 一 shell介绍 什么是shell shell功能 1 什么是shell shell是一个程序 采用C语言编写 是用户和Linux内核沟通的桥梁 它既是一种命令语言 又是一种解释性的编程语言 通过一个图标来查看以下设立了的
  • Codeforces 1469 F. Power Sockets —— 二分+线段树,贪心

    This way 题意 现在有一个根节点 和n条包含a i 个节点的链 一开始所有点的颜色是白色的 你每次可以做以下操作 找到树中某个白色节点 拿出一条链 将这个节点和链上某个节点连接 并且这两个点的颜色变成黑色 之后这条链属于树中一个部分
  • 正则表达式中的“^“这个符号的一些思考

    在学习正则表达式的时候 一些常见的规则我们都不难理解 但是有 一个正则表达式中的特殊字符让我一直有点搞不懂 就是 这个字符 文档上给出了解释是匹配输入的开始 如果多行标示被设置成了true 同时会匹配后面紧跟的字符 比如 A 会匹配 An
  • A Survey of Knowledge-Enhanced Pre-trained Language Models

    本文是LLM系列的文章 针对 A Survey of Knowledge Enhanced Pre trained Language Models 的翻译 知识增强的预训练语言模型综述 摘要 1 引言 2 背景 3 KE PLMs用于NLU

随机推荐

  • Java多线程:解决高并发环境下数据插入重复问题(干货)

    每日一更 最近的问题真是一个接一个 真的让人头大 昨天遇到一个多线程的问题问题描述一下 有一个线程的问题 就是假如 我有一个文件 然后这个文件有很多条数据 假如有两个字段 一个学号一个钱 我的需求是 读取文件 把数据插入到表里 先拿文件的学
  • C# 获取电脑CPU序列号

    ManagementClass 的作用域为 System Management 命名空间 System Management private static string GetCPUSerialNum string cpuSerialNum
  • FRP-内网穿透-frps服务端-WEB管理面板-Dashboard

    FRP 内网穿透 frps服务端 WEB管理面板 Dashboard 1 启动云服务端frps 2 修改配置文件 3 启动 刷新服务 4 异常处理 5 SSH web服务 管理界面 DNS Unix 文件服务 https 安全暴露 点对点内
  • C/C++编译时的Link.EXE错误问题与解决方法

    C C 编译时的Link EXE错误问题与解决方法 作者 Acharlix 1 LIBCD lib wincrt0 obj error LNK2001 unresolved external symbol WinMain 16 问题描述er
  • 第二章排错的工具:调试器Windbg(下)

    感谢博主 http book 51cto com art 200711 59874 htm 2 2 读懂机器的语言 汇编 CPU执行指令的最小单元 2 2 1 需要用汇编来排错的常见情况 汇编是CPU执行指令的最小单元 下面一些情况下 汇编
  • ES Module的基本用法

    import 导入 6种 export app js import from 必须在文件的最顶层 最外层的作用域 路径可以是相对路径 或根目录下的绝对路径 或应完整的url 说明可以引用cdn上的一些文件 但不能以字母开头 js会以为是加载
  • 【Unity开源项目精选】ML-Agents:给你的游戏加入AI

    洪流学堂 让你快人几步 你好 我是你的技术探路者郑洪智 你可以叫我大智 今天给你分享一个Unity开源项目 希望对你有帮助哦 ML Agents Unity机器学习代理工具包 ML Agents 是一个开源项目 它使游戏和仿真能够作为培训智
  • Go语言的TCP和HTTP网络服务基础

    目录 TCP Socket 编程模型 Socket读操作 HTTP网络服务 HTTP客户端 HTTP服务端 TCP IP 网络模型实现了两种传输层协议 TCP 和 UDP 其中TCP 是面向连接的流协议 为通信的两端提供稳定可靠的数据传输服
  • Dump文件分析 - PDB不匹配的情景

    Dump文件分析 PDB不匹配的情景 WinDbg 一 运行程序产生dump 二 WinDbg 基于地址偏移量计算异常地址 方法一 三 WinDbg 强制加载pdb 方法二 参考 总结 WinDbg Windows 调试程序 WinDbg
  • Exps on March 23rd

    电话 固定电话 telephone手机 cellphone mobilephone无绳电话 cordless phone公共电话 paying phone长途电话 long distance call国际电话 international c
  • 「远程开发」VSCode使用SSH远程linux服务器 - 公网远程连接

    文章目录 前言 视频教程 1 安装OpenSSH 2 vscode配置ssh 3 局域网测试连接远程服务器 4 公网远程连接 4 1 ubuntu安装cpolar内网穿透 4 2 创建隧道映射 4 3 测试公网远程连接 5 配置固定TCP端
  • Linux安装mysql5.7.23设置密码问题

    问题 安装mysql没有设置密码导致无法进入mysql 系统 ubuntu 18 04 mysql版本 mysql Ver 14 14 Distrib 5 7 23 for Linux x86 64 using EditLine wrapp
  • 【Linux】HTTPS协议

    目录 前言 HTTPS协议原理 1 概念 2 加密和解密 3 常见加密方式 3 1 对称加密 3 2 非对称加密 4 数据摘要和数据指纹 5 HTTPS工作原理 5 1 方案一 仅对称加密 5 2 方案二 仅非对称加密 5 3 方案三 双方
  • pandas---数据处理(csv文件)

    近期在弄一个项目的前期数据 所以总结了一下 内容如下 以下以csv文件为例 1 DataFrame常用操作 1 1 DataFrame去除空行 1 对于一般空行 2 对于列表式 list 空行 1 2 数据的填充 1 表格中填充0 1 3
  • Springboot actuator端点配置与及基本说明2.2.4版

    pom配置
  • 数据结构与算法之美(01)为什么要学习数据结构和算法?

    你是不是觉得数据结构和算法 跟操作系统 计算机网络一样 是脱离实际工作的知识 可能除了面试 这辈子也用不着 尽管计算机相关专业的同学在大学都学过这门课程 甚至很多培训机构也会培训这方面的知识 但是据我了解 很多程序员对数据结构和算法依旧一窍
  • 解决git中出现的“fatal ‘xxxx‘ does not appear to be a git repository”错误的方法

    今天来分享一下我在使用git中出现的一个错误提示 话不多说 我们直接来分析 这个错误是我在通过SSH方式pull远程仓库时候出现的 错误提示如下 fatal xxx 你的仓库别名 does not appear to be a git re
  • 使用yum命令安装jdk1.8没有jps命令

    问题 使用yum命令安装jdk1 8后 不能使用jps 这是由于没有openjdk devel这个包 使用yum命令下载 yum install java 1 8 0 openjdk devel x86 64 下载完成之后就可以使用jps命
  • Leetcode 刷题笔记(二十八) ——动态规划篇之子序列问题:连续子序列和不连续子序列

    文章目录 系列文章目录 前言 题录 53 最大子数组和 674 最长连续递增序列 300 最长递增子序列 718 最长重复子数组 1143 最长公共子序列 1035 不相交的线 系列文章目录 一 数组类型解题方法一 二分法 二 数组类型解题
  • LSTM原理及实现

    LSTM网络 LSTM核心思想 逐步理解LSTM 遗忘门 输入门 输出门 LSTM变体 多层LSTM LSTM实现手写数字 设置LSTM参数 初始化权值参数 训练 参考资料 前面我们介绍了RNN 现在我们来介绍一种特殊的RNN结构 LSTM