TensorFlow在MNIST中的应用-卷积神经网络CNN

2023-11-16

参考:

《TensorFlow技术解析与实战》


########################################################################################

用TensorFlow搭建一个卷积神经网络CNN模型,并用来训练MNIST数据集。


# -*- coding:utf-8 -*-
# ==============================================================================
# 20171115
# HelloZEX
# 卷积神经网络
# ==============================================================================

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

#定义训练和评估时的批次大小
batch_size = 128
test_size = 256

#初始化权重函数
def init_weights(shape):
    return tf.Variable(tf.random_normal(shape, stddev=0.01))

#神经网络模型的构建,传入以下参数
# x:输入数据
# w:每一层的权重
# p_keep_conv,p_keep_hidden;dropout要保留的神经元比例
def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
    #第一组卷积和池化层,最后dropout一些神经元
    l1a = tf.nn.relu(tf.nn.conv2d(X, w,                       # l1a shape=(?, 28, 28, 32)
                        strides=[1, 1, 1, 1], padding='SAME'))
    l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1],              # l1 shape=(?, 14, 14, 32)
                        strides=[1, 2, 2, 1], padding='SAME')
    l1 = tf.nn.dropout(l1, p_keep_conv)

    #第二组卷积和池化层,最后dropout一些神经元
    l2a = tf.nn.relu(tf.nn.conv2d(l1, w2,                     # l2a shape=(?, 14, 14, 64)
                        strides=[1, 1, 1, 1], padding='SAME'))
    l2 = tf.nn.max_pool(l2a, ksize=[1, 2, 2, 1],              # l2 shape=(?, 7, 7, 64)
                        strides=[1, 2, 2, 1], padding='SAME')
    l2 = tf.nn.dropout(l2, p_keep_conv)

    #第三组卷积和池化层,最后dropout一些神经元
    l3a = tf.nn.relu(tf.nn.conv2d(l2, w3,                     # l3a shape=(?, 7, 7, 128)
                        strides=[1, 1, 1, 1], padding='SAME'))
    l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1],              # l3 shape=(?, 4, 4, 128)
                        strides=[1, 2, 2, 1], padding='SAME')
    l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]])    # reshape to (?, 2048)
    l3 = tf.nn.dropout(l3, p_keep_conv)

    #全连接层,最后dropout一些神经元
    l4 = tf.nn.relu(tf.matmul(l3, w4))
    l4 = tf.nn.dropout(l4, p_keep_hidden)

    #输出层
    pyx = tf.matmul(l4, w_o)
    return pyx

#得到训练和测试的图片
mnist = input_data.read_data_sets("MNIST_Labels_Images", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
trX = trX.reshape(-1, 28, 28, 1)  # 28x28x1 input img
teX = teX.reshape(-1, 28, 28, 1)  # 28x28x1 input imgp_keep_hidden: 1.0})))

X = tf.placeholder("float", [None, 28, 28, 1])
Y = tf.placeholder("float", [None, 10])

#初始化权重
w = init_weights([3, 3, 1, 32])       # patch大小为3x3,输入维度为1 ,输出维度为32
w2 = init_weights([3, 3, 32, 64])     # patch大小为3x3,输入维度为32 ,输出维度为64
w3 = init_weights([3, 3, 64, 128])    # patch大小为3x3,输入维度为64 ,输出维度为128
w4 = init_weights([128 * 4 * 4, 625]) # 全连接层
w_o = init_weights([625, 10])         # 输出层,输入维度为625,输出维度为10代表十个分类(labels)

#我们定义dropout的占位符 keep_conv,他表示在一层中有多少比例的神经元被保留下来。生成网络模型,得到预测值
p_keep_conv = tf.placeholder("float")
p_keep_hidden = tf.placeholder("float")
py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden)

#定义的损失函数,并作均值处理。采用实现RMSProp算法的优化器
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
#定义预测的操作(predict_op)
predict_op = tf.argmax(py_x, 1)

# Launch the graph in a session
with tf.Session() as sess:
    # you need to initialize all variables
    tf.global_variables_initializer().run()

    for i in range(100):
        training_batch = zip(range(0, len(trX), batch_size), range(batch_size, len(trX)+1, batch_size))
        for start, end in training_batch:
            sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end], p_keep_conv: 0.8, p_keep_hidden: 0.5})

        test_indices = np.arange(len(teX)) # Get A Test Batch
        np.random.shuffle(test_indices)
        test_indices = test_indices[0:test_size]

        print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
                         sess.run(predict_op, feed_dict={X: teX[test_indices],
                                                         Y: teY[test_indices],
                                                         p_keep_conv: 1.0,
                                                         p_keep_hidden: 1.0})))

########################################################################################

/usr/bin/python2.7 /home/zhengxinxin/Desktop/PyCharm/Spark/SparkMNIST/SparkMNIST_CNN.py
Extracting MNIST_Labels_Images/train-images-idx3-ubyte.gz
Extracting MNIST_Labels_Images/train-labels-idx1-ubyte.gz
Extracting MNIST_Labels_Images/t10k-images-idx3-ubyte.gz
Extracting MNIST_Labels_Images/t10k-labels-idx1-ubyte.gz
2017-11-15 09:55:12.123463: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
2017-11-15 09:55:12.123492: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
2017-11-15 09:55:12.123497: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2017-11-15 09:55:12.123500: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
2017-11-15 09:55:12.123503: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use FMA instructions, but these are available on your machine and could speed up CPU computations.
(0, 0.9453125)
(1, 0.97265625)
(2, 0.98828125)
(3, 0.984375)
(4, 0.98046875)
(5, 1.0)
(6, 0.984375)
(7, 0.99609375)
(8, 0.9921875)
(9, 0.99609375)
(10, 0.99609375)
(11, 0.99609375)
(12, 0.98828125)
(13, 0.99609375)
(14, 1.0)
(15, 0.98046875)
(16, 0.9921875)
(17, 0.99609375)
(18, 0.99609375)
(19, 1.0)
(20, 0.98828125)
(21, 0.9921875)
(22, 0.98828125)
(23, 1.0)
(24, 1.0)
(25, 0.98828125)
(26, 1.0)
(27, 0.9921875)
(28, 0.9921875)
(29, 0.9921875)
(30, 0.99609375)
(31, 0.99609375)
(32, 1.0)
(33, 1.0)
(34, 0.9921875)
(35, 0.99609375)
(36, 1.0)
(37, 0.9921875)
(38, 0.984375)
(39, 0.99609375)
(40, 0.9921875)
(41, 0.98828125)
(42, 0.98828125)
(43, 1.0)
(44, 1.0)
(45, 0.9921875)
(46, 1.0)
(47, 1.0)
(48, 0.98828125)
(49, 0.9921875)
(50, 0.99609375)
(51, 0.9921875)
(52, 0.9921875)
(53, 0.98828125)
(54, 0.98828125)
(55, 0.98828125)
(56, 0.98828125)
(57, 0.9921875)
(58, 0.99609375)
(59, 0.99609375)
(60, 0.984375)
(61, 0.99609375)
(62, 0.99609375)
(63, 0.99609375)
(64, 0.99609375)
(65, 0.9921875)
(66, 0.99609375)
(67, 0.99609375)
(68, 0.9765625)
(69, 0.99609375)
(70, 0.9921875)
(71, 0.9921875)
(72, 0.99609375)
(73, 0.9921875)
(74, 0.9921875)
(75, 0.9921875)
(76, 0.98828125)
(77, 0.99609375)
(78, 0.99609375)
(79, 0.99609375)
(80, 0.984375)
(81, 0.9921875)
(82, 0.9921875)
(83, 0.98828125)
(84, 0.9765625)
(85, 0.99609375)
(86, 1.0)
(87, 1.0)
(88, 0.984375)
(89, 0.99609375)
(90, 0.9921875)
(91, 0.9921875)
(92, 0.984375)
(93, 0.9921875)
(94, 0.99609375)
(95, 1.0)
(96, 0.99609375)
(97, 1.0)
(98, 0.99609375)
(99, 0.984375)

Process finished with exit code 0
########################################################################################

上面输出了训练的次数和准确度的关系。可以看到100轮后准确度已经非常高了。通过回归模型和卷积神经网络模型,可以看出卷积神经网络的效果非常好。下一节使用RNN训练MNIST。


本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TensorFlow在MNIST中的应用-卷积神经网络CNN 的相关文章

随机推荐

  • K8s存储管理——volume、pv、pvc

    目录 介绍 前言 emptyDir存储卷 hostPath存储卷 本地 NFS共享存储卷 PV PVC NFS使用pv pvc 介绍 来自官方文档 存储的管理是一个与计算实例的管理完全不同的问题 PersistentVolume 子系统为用
  • java基础——内存和变量

    目录 前言 一 java的内存 1 栈内存 2 堆内存 3 方法区内存 二 成员变量与局部变量 1 成员变量 2 局部变量 3 成语变量和局部变量的区别 前言 介绍Java的三种内存分配 栈 堆 方法区 成员变量和局部变量 一 java的内
  • 渗透漏洞 Bugku CTF-Web5

    Bugku CTF Web5 一 开启环境 点击链接 二 查看源代码 发现PHP弱类型 三 构造出 payload 提交获得 flag 一 开启环境 点击链接 二 查看源代码 发现PHP弱类型 1 PHP 比较 2 个值是否相等可以用 或
  • 神策学堂“训练营+特训营”,种子学员招募中,来一起出圈呀!

    2020 年难吗 难 失业 瓶颈 焦虑包围着互联网人 面对这个现状 神策学堂准备了一系列精品课程 芒种训练营 高级特训营 让大家逆流 出圈 神策数据分析芒种训练营 突破瓶颈 晋升高阶岗位 3 场直播 6 实战案例 教你用数据高效赋能业务 1
  • oracle数据库找不到主库,Oracle DG 主库丢失归档

    DG 主库丢失归档 主要原因就是备库没有APP呢 主库就误把归档删除了 常见的这种情况都是主库RMAN做备份的时候把归档删除了 丢失归档解决方法 用RMAN 增量备份恢复 还有恢复控制文件 备库 SQL gt select sequence
  • Centos7.4制作简易RPM包

    准备nginx 1 10 1 tar gz 准备php 7 1 7 tar bz2 这两个源码编译tar包 1 准备制作环境 yum y install rpm build 安装rpm build软件 rpmbuild ba xx spec
  • SpringSecurity最全实战讲解

    文章目录 Spring Security 专题 一 基本概念 认证 授权 会话 RBAC模型 二 一个自己实现的权限模型 BasicAuth 三 SpringBoot Security 快速上手 1 项目搭建步骤 2 用SpringBoot
  • AIF360入门教学

    1 AIF360简介 AI Fairness 360 工具包 AIF360 是一个开源软件工具包 可以帮助检测和缓解整个AI应用程序生命周期中机器学习模型中的偏见 在整个机器学习的过程中 偏见可能存在于初始训练数据 创建分类器的算法或分类器
  • MessageDigest(加密)

    MessageDigest类 MessageDigest 类是一个引擎类 它是为了提供诸如 SHA1 或 MD5 等密码上安全的报文摘要功能而设计的 密码上安全的报文摘要可接受任意大小的输入 一个字节数组 并产生固定大小的输出 该输出称为一
  • 使用 Cloudflare Zero Trust 通过 SSH 连接到 GitHub Actions 的 Runner 机器以进行调试

    GitHub Actions 的 Runner Images 包含了很多常用的开发环境 使用它来构建一些软件是很方便的 不过 构建过程难免会遇到问题 而在 GitHub Actions 上进行构建和在本地有很多不同之处 首先 Runner
  • 服务器装系统都会有哪些坑,小白装机避坑——电脑装系统篇 二

    装机系统分区 首先你需要安装好你的固态硬盘 开机 进入系统 一般用的分区工具都是 DiskGenius 这个软件 粗暴的组装 不需要机箱 一台电脑里面只能设置一个盘作为系统盘 也就是我们的主分区 切记 先对硬件进行测试组装 看看能不能正常启
  • 1125 斐波那契数列

    题目描述 输入整数n 输出斐波那契数列的前n项 输入要求 输入一个整数n 1 lt n lt 12 输出要求 输出斐波那契数列的前n项 每个数后面都有空格 输入样例 6 输出样例 1 1 2 3 5 8 提示 斐波那契数列的排列规则为 第1
  • echarts legend文字颜色

    legend textStyle color fft
  • 一个有意思的let面试题

    今天看到一个面试题 let des 我在外边 let obj des 我在里面 foo function console log this des let bar obj foo bar 这个bar 调用后会打印出什么 本以为是考 this
  • 查看微信小程序的appID和secret

    https mp weixin qq com wxopen devprofile action get profile token 1504304474 lang zh CN 转载于 https www cnblogs com fuckin
  • springmvc源码学习(三十)@ControllerAdvice 全局异常处理

    目录 前言 一 示例 二 原理 前言 在请求到达了 DispatcherServlet 的处理流程 进入 doDispatch 以及后续流程处理业务的过程中出现异常 会进入到 processDispatchResult 处理异常 此时 如果
  • C++-- 如何在类外访问一个类中私有的成员变量?

    如何在类外访问一个类中私有的成员变量 我在网上搜答案的时候看到大部分回答都是在类内部创建一个接口 所以此方法我就不再多做赘述 今天我说的是利用指针 边看代码边理解 上代码 class Test private int a 10 int b
  • win32汇编语言实现冒泡排序

    1 背景 现在大多数的大规模程序并不是由汇编语言来编写 原因很简单 因为太耗时了 但是汇编语言仍然被广泛运用在配置硬件设备以及优化程序的执行速度和尺寸大小等方面 特别是在逆向工程方面 更需要深入理解与熟练掌握汇编语言 针对现阶段 看汇编基本
  • unity04 解决导入fbx文件黑模问题

    左上角window gt rendering gt lighting gt new lighting settings gt 勾选auto generating
  • TensorFlow在MNIST中的应用-卷积神经网络CNN

    参考 TensorFlow技术解析与实战 用TensorFlow搭建一个卷积神经网络CNN模型 并用来训练MNIST数据集 coding utf 8 20171115 HelloZEX 卷积神经网络