TensorFlow入门(五)多层 LSTM 通俗易懂版

2023-05-16

欢迎转载,但请务必注明原文出处及作者信息。

@author: huangyongye
@creat_date: 2017-03-09

前言: 根据我本人学习 TensorFlow 实现 LSTM 的经历,发现网上虽然也有不少教程,其中很多都是根据官方给出的例子,用多层 LSTM 来实现 PTBModel 语言模型,比如:
tensorflow笔记:多层LSTM代码分析
但是感觉这些例子还是太复杂了,所以这里写了个比较简单的版本,虽然不优雅,但是还是比较容易理解。

如果你想了解 LSTM 的原理的话(前提是你已经理解了普通 RNN 的原理),可以参考我前面翻译的博客:
(译)理解 LSTM 网络 (Understanding LSTM Networks by colah)

如果你想了解 RNN 原理的话,可以参考 AK 的博客:
The Unreasonable Effectiveness of Recurrent Neural Networks

很多朋友提到多层怎么理解,所以自己做了一个示意图,希望帮助初学者更好地理解 多层RNN.


图1 3层RNN按时间步展开

本例不讲原理。通过本例,你可以了解到单层 LSTM 的实现,多层 LSTM 的实现。输入输出数据的格式。 RNN 的 dropout layer 的实现。

# -*- coding:utf-8 -*-
import tensorflow as tf
import numpy as np
from tensorflow.contrib import rnn
from tensorflow.examples.tutorials.mnist import input_data

# 设置 GPU 按需增长
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# 首先导入数据,看一下数据的形式
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
print mnist.train.images.shape
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
(55000, 784)

1. 首先设置好模型用到的各个超参数

lr = 1e-3
# 在训练和测试的时候,我们想用不同的 batch_size.所以采用占位符的方式
batch_size = tf.placeholder(tf.int32)  # 注意类型必须为 tf.int32
# 在 1.0 版本以后请使用 :
# keep_prob = tf.placeholder(tf.float32, [])
# batch_size = tf.placeholder(tf.int32, [])

# 每个时刻的输入特征是28维的,就是每个时刻输入一行,一行有 28 个像素
input_size = 28
# 时序持续长度为28,即每做一次预测,需要先输入28行
timestep_size = 28
# 每个隐含层的节点数
hidden_size = 256
# LSTM layer 的层数
layer_num = 2
# 最后输出分类类别数量,如果是回归预测的话应该是 1
class_num = 10

_X = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, class_num])
keep_prob = tf.placeholder(tf.float32)

2. 开始搭建 LSTM 模型,其实普通 RNNs 模型也一样

# 把784个点的字符信息还原成 28 * 28 的图片
# 下面几个步骤是实现 RNN / LSTM 的关键
####################################################################
# **步骤1:RNN 的输入shape = (batch_size, timestep_size, input_size) 
X = tf.reshape(_X, [-1, 28, 28])

# **步骤2:定义一层 LSTM_cell,只需要说明 hidden_size, 它会自动匹配输入的 X 的维度
lstm_cell = rnn.BasicLSTMCell(num_units=hidden_size, forget_bias=1.0, state_is_tuple=True)

# **步骤3:添加 dropout layer, 一般只设置 output_keep_prob
lstm_cell = rnn.DropoutWrapper(cell=lstm_cell, input_keep_prob=1.0, output_keep_prob=keep_prob)

# **步骤4:调用 MultiRNNCell 来实现多层 LSTM
mlstm_cell = rnn.MultiRNNCell([lstm_cell] * layer_num, state_is_tuple=True)

# **步骤5:用全零来初始化state
init_state = mlstm_cell.zero_state(batch_size, dtype=tf.float32)

# **步骤6:方法一,调用 dynamic_rnn() 来让我们构建好的网络运行起来
# ** 当 time_major==False 时, outputs.shape = [batch_size, timestep_size, hidden_size] 
# ** 所以,可以取 h_state = outputs[:, -1, :] 作为最后输出
# ** state.shape = [layer_num, 2, batch_size, hidden_size], 
# ** 或者,可以取 h_state = state[-1][1] 作为最后输出
# ** 最后输出维度是 [batch_size, hidden_size]
# outputs, state = tf.nn.dynamic_rnn(mlstm_cell, inputs=X, initial_state=init_state, time_major=False)
# h_state = outputs[:, -1, :]  # 或者 h_state = state[-1][1]

# *************** 为了更好的理解 LSTM 工作原理,我们把上面 步骤6 中的函数自己来实现 ***************
# 通过查看文档你会发现, RNNCell 都提供了一个 __call__()函数(见最后附),我们可以用它来展开实现LSTM按时间步迭代。
# **步骤6:方法二,按时间步展开计算
outputs = list()
state = init_state
with tf.variable_scope('RNN'):
    for timestep in range(timestep_size):
        if timestep > 0:
            tf.get_variable_scope().reuse_variables()
        # 这里的state保存了每一层 LSTM 的状态
        (cell_output, state) = mlstm_cell(X[:, timestep, :], state)
        outputs.append(cell_output)
h_state = outputs[-1]

3. 设置 loss function 和 优化器,展开训练并完成测试

  • 以下部分其实和之前写的 TensorFlow入门(三)多层 CNNs 实现 mnist分类 的对应部分是一样的。
# 上面 LSTM 部分的输出会是一个 [hidden_size] 的tensor,我们要分类的话,还需要接一个 softmax 层
# 首先定义 softmax 的连接权重矩阵和偏置
# out_W = tf.placeholder(tf.float32, [hidden_size, class_num], name='out_Weights')
# out_bias = tf.placeholder(tf.float32, [class_num], name='out_bias')
# 开始训练和测试
W = tf.Variable(tf.truncated_normal([hidden_size, class_num], stddev=0.1), dtype=tf.float32)
bias = tf.Variable(tf.constant(0.1,shape=[class_num]), dtype=tf.float32)
y_pre = tf.nn.softmax(tf.matmul(h_state, W) + bias)


# 损失和评估函数
cross_entropy = -tf.reduce_mean(y * tf.log(y_pre))
train_op = tf.train.AdamOptimizer(lr).minimize(cross_entropy)

correct_prediction = tf.equal(tf.argmax(y_pre,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))


sess.run(tf.global_variables_initializer())
for i in range(2000):
    _batch_size = 128
    batch = mnist.train.next_batch(_batch_size)
    if (i+1)%200 == 0:
        train_accuracy = sess.run(accuracy, feed_dict={
            _X:batch[0], y: batch[1], keep_prob: 1.0, batch_size: _batch_size})
        # 已经迭代完成的 epoch 数: mnist.train.epochs_completed
        print "Iter%d, step %d, training accuracy %g" % ( mnist.train.epochs_completed, (i+1), train_accuracy)
    sess.run(train_op, feed_dict={_X: batch[0], y: batch[1], keep_prob: 0.5, batch_size: _batch_size})

# 计算测试数据的准确率
print "test accuracy %g"% sess.run(accuracy, feed_dict={
    _X: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0, batch_size:mnist.test.images.shape[0]})
Iter0, step 200, training accuracy 0.851562
Iter0, step 400, training accuracy 0.960938
Iter1, step 600, training accuracy 0.984375
Iter1, step 800, training accuracy 0.960938
Iter2, step 1000, training accuracy 0.984375
Iter2, step 1200, training accuracy 0.9375
Iter3, step 1400, training accuracy 0.96875
Iter3, step 1600, training accuracy 0.984375
Iter4, step 1800, training accuracy 0.992188
Iter4, step 2000, training accuracy 0.984375
test accuracy 0.9858

我们一共只迭代不到5个epoch,在测试集上就已经达到了0.9825的准确率,可以看出来 LSTM 在做这个字符分类的任务上还是比较有效的,而且我们最后一次性对 10000 张测试图片进行预测,才占了 725 MiB 的显存。而我们在之前的两层 CNNs 网络中,预测 10000 张图片一共用了 8721 MiB 的显存,差了整整 12 倍呀!! 这主要是因为 RNN/LSTM 网络中,每个时间步所用的权值矩阵都是共享的,可以通过前面介绍的 LSTM 的网络结构分析一下,整个网络的参数非常少。

4. 可视化看看 LSTM 的是怎么做分类的

毕竟 LSTM 更多的是用来做时序相关的问题,要么是文本,要么是序列预测之类的,所以很难像 CNNs 一样非常直观地看到每一层中特征的变化。在这里,我想通过可视化的方式来帮助大家理解 LSTM 是怎么样一步一步地把图片正确的给分类。

import matplotlib.pyplot as plt

看下面我找了一个字符 3

print mnist.train.labels[4]
[ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.]

我们先来看看这个字符样子,上半部分还挺像 2 来的

X3 = mnist.train.images[4]
img3 = X3.reshape([28, 28])
plt.imshow(img3, cmap='gray')
plt.show()

这里写图片描述

我们看看在分类的时候,一行一行地输入,分为各个类别的概率会是什么样子的。

X3.shape = [-1, 784]
y_batch = mnist.train.labels[0]
y_batch.shape = [-1, class_num]

X3_outputs = np.array(sess.run(outputs, feed_dict={
            _X: X3, y: y_batch, keep_prob: 1.0, batch_size: 1}))
print X3_outputs.shape
X3_outputs.shape = [28, hidden_size]
print X3_outputs.shape
(28, 1, 256)
(28, 256)
h_W = sess.run(W, feed_dict={
            _X:X3, y: y_batch, keep_prob: 1.0, batch_size: 1})
h_bias = sess.run(bias, feed_dict={
            _X:X3, y: y_batch, keep_prob: 1.0, batch_size: 1})
h_bias.shape = [-1, 10]

bar_index = range(class_num)
for i in xrange(X3_outputs.shape[0]):
    plt.subplot(7, 4, i+1)
    X3_h_shate = X3_outputs[i, :].reshape([-1, hidden_size])
    pro = sess.run(tf.nn.softmax(tf.matmul(X3_h_shate, h_W) + h_bias))
    plt.bar(bar_index, pro[0], width=0.2 , align='center')
    plt.axis('off')
plt.show()

这里写图片描述

在上面的图中,为了更清楚地看到线条的变化,我把坐标都去了,每一行显示了 4 个图,共有 7 行,表示了一行一行读取过程中,模型对字符的识别。可以看到,在只看到前面的几行像素时,模型根本认不出来是什么字符,随着看到的像素越来越多,最后就基本确定了它是字符 3.

好了,本次就到这里。有机会再写个优雅一点的例子,哈哈。其实学这个 LSTM 还是比较困难的,当时写 多层 CNNs 也就半天到一天的时间基本上就没啥问题了,但是这个花了我大概整整三四天,而且是在我对原理已经很了解(我自己觉得而已。。。)的情况下,所以学会了感觉还是有点小高兴的~

17-04-19补充几个资料:
- recurrent_network.py 一个简单的 tensorflow LSTM 例子。
- Tensorflow下构建LSTM模型进行序列化标注 介绍非常好的一个 NLP 开源项目。(例子中有些函数可能在新版的 tensorflow 中已经更新了,但并不影响理解)

5. 附:BASICLSTM.__call__()

'''code: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py'''

  def __call__(self, inputs, state, scope=None):
      """Long short-term memory cell (LSTM)."""
      with vs.variable_scope(scope or "basic_lstm_cell"):
          # Parameters of gates are concatenated into one multiply for efficiency.
          if self._state_is_tuple:
              c, h = state
          else:
              c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
          concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope)

          # ** 下面四个 tensor,分别是四个 gate 对应的权重矩阵
          # i = input_gate, j = new_input, f = forget_gate, o = output_gate
          i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)

          # ** 更新 cell 的状态: 
          # ** c * sigmoid(f + self._forget_bias) 是保留上一个 timestep 的部分旧信息
          # ** sigmoid(i) * self._activation(j)  是有当前 timestep 带来的新信息
          new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
               self._activation(j))

          # ** 新的输出
          new_h = self._activation(new_c) * sigmoid(o)

          if self._state_is_tuple:
              new_state = LSTMStateTuple(new_c, new_h)
          else:
              new_state = array_ops.concat([new_c, new_h], 1)
          # ** 在(一般都是) state_is_tuple=True 情况下, new_h=new_state[1]
          # ** 在上面博文中,就有 cell_output = state[1]
          return new_h, new_state

本文代码:https://github.com/yongyehuang/Tensorflow-Tutorial

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TensorFlow入门(五)多层 LSTM 通俗易懂版 的相关文章

随机推荐

  • 个人面试经历经验谈

    到昨天接到金蝶得Offer xff0c 我想我为期三个星期的找工作面试之旅应该是告一段落了 原以为接到Offer会有点高兴 xff0c 但是一回味这三个星期的起起落落 xff0c 便实在是高兴不起来 xff0c 虽然手上有好几个Offer可
  • linux 查看端口占用情况

    1 查看系统端口 netstat anptl显示所有正在监听的端口 2 刷选某个端口 netstat anptl grep 39 3350 39 3 查看占用端口的应用程序 ps lt PID gt 下图中可以看到端口8080 的PID 6
  • eclipse tomcat部署项目开发环境修改访问路径

    eclipse tomcat部署项目开发环境修改访问路径
  • 欢迎使用CSDN-markdown编辑器

    欢迎使用Markdown编辑器写博客 本Markdown编辑器使用StackEdit修改而来 xff0c 用它写博客 xff0c 将会带来全新的体验哦 xff1a Markdown和扩展Markdown简洁的语法代码块高亮图片链接和图片上传
  • PLSQL 11 注册码

    PLSQL 11 注册码 注册码 xff1a Product Code xff1a 4t46t6vydkvsxekkvf3fjnpzy5wbuhphqz serial Number xff1a 601769 password xff1a x
  • ORA-28000: the account is locked-的解决办法

    ORA 28000 the account is locked 第一步 xff1a 使用PL SQL xff0c 登录名为system 数据库名称不变 xff0c 选择类型的时候把Normal修改为Sysdba 第二步 xff1a 选择my
  • IOS中bootstrap-select 动态加载的下拉框点击不展示(已解决)

    问题描述 bootstrap select 动态加载的option 在安卓浏览器中能点击后展示 但是在ios浏览器中点击没反应 发现原因 在重新渲染方法后加了一行 s e
  • S7-PLCSIM 无法找到STEP 7 V15 许可证(必须在此计算机上安装STEP V15应用程序)。-----(已解决)

    已经安装过step7 并且已经授权过了 xff0c 但是启动时提示下图错误 记得右键已管理员身份运行
  • 新建springboot项目, pom.xml报错 Unkown error 解决思路

    新建项目pom xml 报错 网上解决思路 xff1a 1 大多都是项目右键 Maven Update Project 选中Force Update of Snapshots Releases 进行强制更新 2 1 5 改成了2 1 3 修
  • 嵌入式软件面试总结

    背景 先说说本人的背景 xff0c 我 xff0c 一个大专人 xff0c 从事嵌入式开发两年了 xff0c 之前在一家公司是负责单片机和物联网开发的 2020年年底我选择了裸辞 xff08 主要想出去玩 xff09 直到春节结束后 xff
  • Intel NUC安装ubuntu系统的方法

    使用intel nuc安装ubuntu系统 xff0c 试验了好多次UEFI安装 xff0c 但是结果都是开机时会出现 A bootable device 除了这句话都是黑屏的现象 原因我查了很多 xff0c 也不敢确定 xff0c 现在总
  • 白骑士的树莓派教学(二):镜像烧录

    本期内容让我们来了解一下树莓派操作系统镜像烧录的操作 xff0c 所需的设备 xff1a PC机 xff0c U盘 xff0c 树莓派相关设备 什么是镜像 xff1f 所谓镜像文件其实和ZIP压缩包类似 xff0c 它将特定的一系列文件按照
  • VS Code Remote SSH远程连接异常:Resolver error: Error: Running the contributed command

    VS Code Remote SSH远程连接异常 问题描述原因分析解决方案扩展Remote SSH首次连接插件做了什么Remote SSH对于远程Linux的要求 问题描述 通过VS Code插件Remote SSH连接一台新主机时 xff
  • PHP常用六大设计模式

    单例模式 特点 xff1a 三私一公 xff1a 私有的静态变量 xff08 存放实例 xff09 xff0c 私有的构造方法 xff08 防止创建实例 xff09 xff0c 私有的克隆方法 防止克隆对象 xff0c 公有的静态方法 xf
  • matlab中文乱码的解决(UTF-8不支持的问题)

    1 解决editor中的UTF 8不支持的问题 xff0c 需要加入下面几行 在matlab 安装的目录的bin子文件夹中找到lcdata xml文件 xff1a 打开加入 lt Locale entries example gt lt l
  • FreeRTOS分析

    freertos是一个轻量级的rtos xff0c 它目前实现了一个微内核 xff0c 并且port到arm7 avr pic18 coldfire等众多处理器上 xff1b 目前已经在rtos的市场上占有不少的份额 它当然不是一个与vxw
  • STM32之FreeRTOS

    学习操作系统 xff0c 我并没有一开始就学习UCOS xff0c 而是选择了FreeRTOS FreeRTOS可以方便地搭建在各个平台上 xff0c 因为汇编相关 xff0c 都已经由官方完成 xff0c 我们要做的仅是添加自己的代码 x
  • FrankMocap Fast monocular 3D Hand and Body Motion Capture by Regression and Intergretion

    paper title FrankMocap Fast monocular 3D Hand and Body Motion Capture by Regression and Intergretion paper link https ar
  • 矩阵中的路径(C++)

    题目 xff1a 请设计一个函数 xff0c 用来判断在一个矩阵中是否存在一条包含某字符串所有字符的路径 路径可以从矩阵中的任意一个格子开始 xff0c 每一步可以在矩阵中向左 xff0c 向右 xff0c 向上 xff0c 向下移动一个格
  • TensorFlow入门(五)多层 LSTM 通俗易懂版

    欢迎转载 xff0c 但请务必注明原文出处及作者信息 64 author huangyongye 64 creat date 2017 03 09 前言 根据我本人学习 TensorFlow 实现 LSTM 的经历 xff0c 发现网上虽然