TensorFlow在MNIST中的应用-循环神经网络RNN

2023-11-18

参考:

1. 《TensorFlow技术解析与实战》

2. https://www.cnblogs.com/hellcat/p/7401706.html

3. http://www.jianshu.com/p/3dbeb3ab9aa3

########################################################################################

用TensorFlow搭建一个循环神经网络RNN模型,并用来训练MNIST数据集。

RNN在自安然语言处理领域的以下几个方向应用非常成功:

1.机器翻译

2.语音识别

3.图像描述

4.语言单词预测


# -*- coding:utf-8 -*-
# ==============================================================================
# 20171115
# HelloZEX
# 循环神经网络RNN
# Code from https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/recurrent_network.py
# ==============================================================================
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

tf.set_random_seed(1)  # set random seed

# 导入数据
mnist = input_data.read_data_sets("MNIST_Labels_Images", one_hot=True)

# hyperparameters
lr = 0.001  # learning rate
training_iters = 100000  # train step 上限
batch_size = 128
n_inputs = 28  # MNIST data input (img shape: 28*28)
n_steps = 28  # time steps
n_hidden_units = 128  # neurons in hidden layer
n_classes = 10  # MNIST classes (0-9 digits)

# x y placeholder
x = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.float32, [None, n_classes])

# 对 weights biases 初始值的定义
weights = {
    # shape (28, 128)
    'in': tf.Variable(tf.random_normal([n_inputs, n_hidden_units])),
    # shape (128, 10)
    'out': tf.Variable(tf.random_normal([n_hidden_units, n_classes]))
}
biases = {
    # shape (128, )
    'in': tf.Variable(tf.constant(0.1, shape=[n_hidden_units, ])),
    # shape (10, )
    'out': tf.Variable(tf.constant(0.1, shape=[n_classes, ]))
}


def RNN(X, weights, biases):
    # 原始的 X 是 3 维数据, 我们需要把它变成 2 维数据才能使用 weights 的矩阵乘法
    # X ==> (128 batches * 28 steps, 28 inputs)
    X = tf.reshape(X, [-1, n_inputs])

    # X_in = W*X + b
    X_in = tf.matmul(X, weights['in']) + biases['in']
    # X_in ==> (128 batches, 28 steps, 128 hidden) 换回3维
    X_in = tf.reshape(X_in, [-1, n_steps, n_hidden_units])

    # 使用 basic LSTM Cell.
    lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden_units, forget_bias=1.0, state_is_tuple=True)
    init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)  # 初始化全零 state
    outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in, initial_state=init_state, time_major=False)

    # 把 outputs 变成 列表 [(batch, outputs)..] * steps
    # list(tensor1, tensor2...)
    outputs = tf.unstack(tf.transpose(outputs, [1, 0, 2]))

    # 这里取的是所有图片最后step计算出的张量组(128,128)
    # 依照RNN的思想,最后的step中包含了前面各组step的信息
    # 应属常识但仍然强调一下的一点是在各中网络框架以及设计中同组batch用的参数相同而且是互不影响的
    results = tf.matmul(outputs[-1], weights['out']) + biases['out']  # 选取最后一个 output
    return results


pred = RNN(x, weights, biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
train_op = tf.train.AdamOptimizer(lr).minimize(cost)

correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    step = 0
    while step * batch_size < training_iters:
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        batch_xs = batch_xs.reshape([batch_size, n_steps, n_inputs])
        sess.run([train_op], feed_dict={
            x: batch_xs,
            y: batch_ys,
        })
        if step % 20 == 0:
            print(sess.run(accuracy, feed_dict={
                x: batch_xs,
                y: batch_ys,
            }))
        step += 1

print('Finish!')
输出:

/usr/bin/python2.7 /home/zhengxinxin/Desktop/PyCharm/Spark/SparkMNIST/SparkMNIST_RNN.py
Extracting MNIST_Labels_Images/train-images-idx3-ubyte.gz
Extracting MNIST_Labels_Images/train-labels-idx1-ubyte.gz
Extracting MNIST_Labels_Images/t10k-images-idx3-ubyte.gz
Extracting MNIST_Labels_Images/t10k-labels-idx1-ubyte.gz
2017-11-15 10:56:00.015122: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
2017-11-15 10:56:00.015152: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
2017-11-15 10:56:00.015157: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2017-11-15 10:56:00.015160: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
2017-11-15 10:56:00.015164: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use FMA instructions, but these are available on your machine and could speed up CPU computations.
0.15625
0.671875
0.804688
0.796875
0.828125
0.84375
0.90625
0.945312
0.945312
0.945312
0.851562
0.90625
0.960938
0.890625
0.90625
0.90625
0.921875
0.945312
0.96875
0.9375
0.921875
0.953125
0.960938
0.929688
0.976562
0.96875
0.960938
0.945312
0.960938
0.976562
0.945312
0.976562
0.929688
0.96875
0.960938
0.945312
0.976562
0.9375
0.96875
0.960938
Finish!

Process finished with exit code 0

把RNN函数提出来单独讲解一下:

输入数据格式应该是[batch,step,input],对应下图中前向传播过程,



(图片引用自参考2)


关键内容:

lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden_units, forget_bias=1.0, state_is_tuple=True)        # 生成cell核心
init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)                                       # 初始cell状态
outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in, initial_state=init_state, time_major=False) # rnn计算,会一次性计算出所有的step

对于第三步,实际上tf.contrib.rnn.BasicLSTMCell对象是callable的参数是当前step的input和前一时刻的state,所以每次计算想要一个个step自己迭代的话参看参考3  tensorflow 循环神经网络RNN


这里的思路是就是RNN输入维度正好是一次一张图片(或者说一个数据),最后一step正好包含了本张图片前面所有的信息所以取它的output,正是因此提取时需要改维度为{step,batch,output}以方便提取全图片(batch)的最后一个step的output:

outputs = tf.unstack(tf.transpose(outputs, [1,0,2]))

[注]:其实不需要unstack,这个函数会把拆分后维度(默认为0)的张量改为包含小张量list,去掉了这个函数也不影响后面output[-1]的操作,张量和list都可以接收slice操作


从输出来看其实不如CNN在MNIST上面的效果好,虽然仅差一点点。


本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TensorFlow在MNIST中的应用-循环神经网络RNN 的相关文章

随机推荐

  • Visual Studio Code 插件

    Visual Studio Code 插件安装 插件安装 Script插件 Vue插件 插件安装 首页点击 工具和语言 如下图 接下来 在输入框中输入想要安装的插件的名字 点击 install 即可进行安装 Script插件 1 Eslin
  • C++中cout和cerr的区别?

    之前一直在用 但就是没在意两者到底有啥却别 今天又想到这个问题 总结下吧 以下的内容均是本人从网上查阅资料看来整理的 暂时还没有查阅官方资料 不保证准确 欢迎讨论 其实大家平常常会用的主要有三个 cout cerr clog 首先简单介绍下
  • http协议的状态码:404等常见网页错误代码

    http协议的状态码 一 1xx 临时响应 表示临时响应并需要请求者继续执行操作的状态码 100 继续 请求者应当继续提出请求 服务器返回此代码表示已收到请求的第一部分 正在等待其余部分 101 切换协议 请求者已要求服务器切换协议 服务器
  • 滴滴夜莺:从监控告警系统向运维平台演化

    简述 滴滴夜莺 Nightingale 是一款经过大规模生产环境验证的 分布式高性能的运维监控系统 基于Open Falcon 结合滴滴内部的最佳实践 在性能 可维护性 易用性方面做了大量的改进 支撑了滴滴内部数十亿监控指标 覆盖了从系统
  • 惊呆了!女儿拿着小天才电话手表,问我Android启动流程!

    首先 new一个女儿 var mDdaughter new 女儿 6岁 漂亮可爱 健康乖巧 最喜欢玩小天才电话手表和她的爸爸 好了 女儿有了 有一天 女儿问我 爸爸爸爸 你说我玩的这个小天才电话手表怎么这么厉害 随便点一下这个小图片 这个应
  • Manifest合并失败几种原因以及解决方法

    今天遇到了一个报错 Error Execution failed for task app processDebugManifest gt Manifest merger failed with multiple errors see lo
  • c语言合并两个单链表LA和LB,把两个递增的单链表La,Lb,合并成一个递减的单链表Lc...

    原文题是严蔚敏同志的数据结构习题中第二章线性表中提出的问题 原问如下 2 24 假设有两个按元素值递增有序排列的线性表A和B 均以单链表作存储结构 请编写算法将A表与B表归并成一个按元素值递减有序 即非递增有序 允许表中含有值相同的元表 排
  • 基于Vue + vuex + Antd-design-vue实现天气App

    simple weather github 地址 github com WqhForGitHu 效果图 PC端 移动设备端 技术框架 该应用是基于 Vue vuex 实现的 页面的 UI 则是使用了 Antd design vue 库来完成
  • Android 版本统一管理

    前言 因为现在项目都比较模块化 组件化 要用到的model比较多 一个model就有一个build gradle文件 里面都有compileSdkVersion或buildToolsVersion等可能出现版本不一致导致编译出现错误 所以要
  • thinkphp5学习路程 三 数据库操作

    首先我用的是php中文网提供的php工具箱 phpmyadmin管理mysql 在此之前最好对sql语句有所了解 会简单的增删改查等 在里面创建数据库和一张表如下 随后你需要打开数据库的配置文件 目录为 application databa
  • Python OpenCV中的图像阈值处理

    1 前言 上一篇介绍了用C 如何对一幅图像进行阈值处理 本篇接着用python来做同样的事情 图像阈值处理是很多高级算法的底层逻辑之一 比如在做图形检测 轮廓识别时 常常会先对图像进行阈值处理 然后再进行具体的检测或识别 因此很有必要掌握图
  • 指针作函数返回值

    include
  • 指向数组的引用 const char(&p)[a]

    指向数组的引用 const char p a 问题起源 如何在函数内 也能获取数组的大小信息 如果是定义一个数组a后 使用如下方法即可获取大小信息 cout lt lt sizeof a sizeof a 0 但是如果作为一个参数传入到一个
  • 最新酒桌小游戏喝酒小程序源码_带流量主源码下载

    2022最新酒桌小游戏喝酒小程序源码 带流量主 喝酒神器3 6 我修改增加了广告位 根据文档直接替换即可 原版本没有广告位 直接上传源码到开发者端即可 通过后改广告代码 然后关闭广告展示提交 通过后打开即可 下载地址 最新酒桌小游戏喝酒小程
  • 在linux系统下安装配置apache服务器

    我所用的是centos linux系统 但apache的服务在linux系统都大同小异 像ubuntu redhat等等 now let us go 如有问题 欢迎直邮 zhe jiang he hp com lt 何哲江 gt 1 获取软
  • Edge浏览器没有让我失望! 今天终于可以在win10中模拟IE内核进行前端测试了!

    前言 ietest现在是不是不好用了 Edge浏览器仿真是不是不见了 如图 如果我们在前端开发javascript遇见一些老旧的语法标准 想要测试一下都难 想想都抓狂 不过不用担心 经过这几天的资料查阅 我还是找到了一个解决办法来模拟旧版I
  • Set集合中的SortedSet接口下的实现类TreeSet

    放入TreeSet集合中的元素必须实现Comparable接口 不然会报错 因为这个集合中的元素会自动按元素的大小顺序排序 所以不是实现比较的接口就会出现ClassCastException 还要注意一点的是Set集合中的元素是不可重读的
  • ctfshow web入门刷题3

    web15 看提示找到邮箱 然后尝试登入后台 url admin 尝试点击忘记密码然后提示输入城市 尝试用qq搜索qq号 发现城市为西安 得到后台密码 登入得到flag WEB16 题目提示php探针 所以url tz php打开探针然后搜
  • 当你穿越到道诡异仙的世界,如何利用密码学知识区分幻想和现实?

    题解 牛群的能量 题目考察的知识点动态规划题目解答方法的文字分析用 f i 代表以第 i个数结尾的 和最大子群能量值之和 设数组的长度为n 则本题的答案时从0到n 1这n个f 题解 牛牛的名字游戏 题目考察的知识点字符串题目解答方法的文字分
  • TensorFlow在MNIST中的应用-循环神经网络RNN

    参考 1 TensorFlow技术解析与实战 2 https www cnblogs com hellcat p 7401706 html 3 http www jianshu com p 3dbeb3ab9aa3 用TensorFlow搭