深度学习之手写数字识别

2023-11-15

      当我们开始学习编程的时候,第一件事往往是学习打印"Hello World"。就好比编程入门有Hello World,机器学习入门有MNIST。      

     

MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片:

它也包含每一张图片对应的标签,告诉我们这个是数字几。比如,上面这四张图片的标签分别是5,0,4,1。

MNIST数据集

MNIST数据集的官网是Yann LeCun's website。在这里,我们提供了一份python源代码用于自动下载和安装这个数据集。你可以下载这份代码,然后用下面的代码导入到你的项目里面,也可以直接复制粘贴到你的代码文件里面。

import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

下载下来的数据集被分成两部分:60000行的训练数据集(mnist.train)和10000行的测试数据集(mnist.test)。这样的切分很重要,在机器学习模型设计时必须有一个单独的测试数据集不用于训练而是用来评估这个模型的性能,从而更加容易把设计的模型推广到其他数据集上(泛化)。

正如前面提到的一样,每一个MNIST数据单元有两部分组成:一张包含手写数字的图片和一个对应的标签。我们把这些图片设为“xs”,把这些标签设为“ys”。训练数据集和测试数据集都包含xs和ys,比如训练数据集的图片是 mnist.train.images ,训练数据集的标签是 mnist.train.labels


当你安装了tensorflow后,tensorflow自带的教程演示了如何使用卷积神经网络来识别手写数字,tensorFlow会自己下载图片训练集。下面这两行代码 会自动创建一个 'MNIST_data' 的目录来存储数据。
import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)



      这里,mnist是一个轻量级的类。它以Numpy数组的形式存储着训练、校验和测试数据集。同时提供了一个函数,用于在迭代中获得minibatch,后面我们将会用到。



这个是TensorFlow官方教程《深入MNIST》中的完整代码。完整教程在这里。 

import tensorflow as tf

#导入input_data用于自动下载和安装MNIST数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

#创建一个交互式Session
sess = tf.InteractiveSession()

#创建两个占位符,x为输入网络的图像,y_为输入网络的图像类别
x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])

#权重初始化函数
def weight_variable(shape):
    #输出服从截尾正态分布的随机值
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

#偏置初始化函数
def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

#创建卷积op
#x 是一个4维张量,shape为[batch,height,width,channels]
#卷积核移动步长为1。填充类型为SAME,可以不丢弃任何像素点
def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding="SAME")

#创建池化op
#采用最大池化,也就是取窗口中的最大值作为结果
#x 是一个4维张量,shape为[batch,height,width,channels]
#ksize表示pool窗口大小为2x2,也就是高2,宽2
#strides,表示在height和width维度上的步长都为2
def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1,2,2,1],
                          strides=[1,2,2,1], padding="SAME")


#第1层,卷积层
#初始化W为[5,5,1,32]的张量,表示卷积核大小为5*5,第一层网络的输入和输出神经元个数分别为1和32
W_conv1 = weight_variable([5,5,1,32])
#初始化b为[32],即输出大小
b_conv1 = bias_variable([32])

#把输入x(二维张量,shape为[batch, 784])变成4d的x_image,x_image的shape应该是[batch,28,28,1]
#-1表示自动推测这个维度的size
x_image = tf.reshape(x, [-1,28,28,1])

#把x_image和权重进行卷积,加上偏置项,然后应用ReLU激活函数,最后进行max_pooling
#h_pool1的输出即为第一层网络输出,shape为[batch,14,14,1]
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)

#第2层,卷积层
#卷积核大小依然是5*5,这层的输入和输出神经元个数为32和64
W_conv2 = weight_variable([5,5,32,64])
b_conv2 = weight_variable([64])

#h_pool2即为第二层网络输出,shape为[batch,7,7,1]
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

#第3层, 全连接层
#这层是拥有1024个神经元的全连接层
#W的第1维size为7*7*64,7*7是h_pool2输出的size,64是第2层输出神经元个数
W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])

#计算前需要把第2层的输出reshape成[batch, 7*7*64]的张量
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

#Dropout层
#为了减少过拟合,在输出层前加入dropout
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

#输出层
#最后,添加一个softmax层
#可以理解为另一个全连接层,只不过输出时使用softmax将网络输出值转换成了概率
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])

y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

#预测值和真实值之间的交叉墒
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))

#train op, 使用ADAM优化器来做梯度下降。学习率为0.0001
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

#评估模型,tf.argmax能给出某个tensor对象在某一维上数据最大值的索引。
#因为标签是由0,1组成了one-hot vector,返回的索引就是数值为1的位置
correct_predict = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))

#计算正确预测项的比例,因为tf.equal返回的是布尔值,
#使用tf.cast把布尔值转换成浮点数,然后用tf.reduce_mean求平均值
accuracy = tf.reduce_mean(tf.cast(correct_predict, "float"))

#初始化变量
sess.run(tf.initialize_all_variables())

#开始训练模型,循环20000次,每次随机从训练集中抓取50幅图像
for i in range(20000):
    batch = mnist.train.next_batch(50)
    if i%100 == 0:
        #每100次输出一次日志
        train_accuracy = accuracy.eval(feed_dict={
            x:batch[0], y_:batch[1], keep_prob:1.0})
        print ("step %d, training accuracy %g" % (i, train_accuracy))

    train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5})

print ("test accuracy %g" % accuracy.eval(feed_dict={
    x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))


实验过程:

因为单机性能有限,所以程序运行时间比较长,我们发现,越到后面,准确率越高





参考地址:http://www.tensorfly.cn/tfdoc/tutorials/mnist_pros.html

MNIST 数据下载http://www.tensorfly.cn/tfdoc/tutorials/mnist_download.html




本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习之手写数字识别 的相关文章

  • 【从零开始学c++】——string

    学好STL 一 STL简介 了解 1 什么是STL 2 STL的六大组件 3 STL的缺陷 2 string 1 string的简单了解 如何对stl的查阅 2 string常用接口说明 1 string类 对象常见的构造 2 string
  • Kotlin入门学习(非常详细),从零基础入门到精通,看完这一篇就够了

    文章目录 kotlin的历史 Kotlin的工作原理 语言类型 编译型 解释型 Java的语言类型 Kotlin的运行原理 创建Kotlin项目 语法 变量 变量的声明 基本类型 var和val的本质区别 函数 函数的声明 声明技巧 函数的
  • 找准边界,吃定安全

    创新的资源管理算法 基于会话的全分布式处理流程 山石网科全分布式架构 打破了传统架构的限制 找准边界 吃定安全 往期文章 从访问控制谈起 再看零信任模型 威胁情报加持 泛边界下的全局主动防御体系如何着手 随着 2019 年我国以信息网络等新
  • 连 连 看

    1 案例介绍 连连看是一款曾经非常流行的小游戏 游戏规则 点击选中两个相同的方块 两个选中的方块之间连接线的折点不超过两个 接线由X轴和Y轴的平行线组成 每找出一对 它们就会自动消失 连线不能从尚未消失的图案上经过 把所有的图案全部消除即可
  • C/C++之01背包问题

    问题描述 给定N个物品 每个物品有一个重量W和一个价值V 你有一个能装M重量的背包 问怎么装使得所装价值最大 每个物品只有一个 输入格式 输入的第一行包含两个整数n m 分别表示物品的个数和背包能装重量 以后N行每行两个数Wi和Vi 表示物

随机推荐

  • <form>表单

    1 form表单
  • osgEarth的Rex引擎原理分析(三十六)为什么要删除设置过的垂直水准面

    目标 二十九 中的问题86 椭球体 水平面 应该不是删除 而是信息创建出一个没有垂直水准面的Profile 待继续分析列表 9 earth文件中都有哪些options 九 中问题 10 如何根据earth文件options创建不同的地理信息
  • java gc 次数_浅谈如何减少GC的次数

    GC会stop the world 会暂停程序的执行 带来延迟的代价 所以在开发中 我们不希望GC的次数过多 本文将讨论如何在开发中改善各种细节 从而减少GC的次数 1 对象不用时最好显式置为 Null 一般而言 为 Null 的对象都会被
  • 应用程序无法正常启动0xc000007b请点击确定关闭应用程序

    应用程序无法正常启动0xc000007b怎么办 这是很多用户在电脑的使用过程中会出现的一个问题 究竟出现这个问题的时候 我们要怎么去解决它 让我们的电脑重新恢复正常使用呢 想要解决这个问题就一起来看看0xc000007b错误解决办法吧 0x
  • nRF52832学习记录(一、外设初识之 GPIOTE)

    添加GPIO和GPIOTE寄存器表 对于应用的理解对着寄存器查看会比较明了 这个不管是在哪款芯片上都是如此 2021 9 27 这些年蓝牙5 0的应用越来越多 最近也是想着把以前Enocean的低功耗设备有过的产品 用蓝牙做一套匹配的版本
  • pikachu靶场的两道RCE

    第一道题 ping一个ip并查看当前目录 输入127 0 0 1 点击ping 出来一堆乱码 第一种方法 按win r键 输入regedit 点击确定 即打开注册表编辑器 打开HKEY CURRENT USER项 打开其中的Console项
  • 浅谈Python网络爬虫应对反爬虫的技术对抗

    在当今信息时代 数据是非常宝贵的资源 而作为一名专业的 Python 网络爬虫程序猿 在进行网页数据采集时经常会遭遇到各种针对爬虫行为的阻碍和限制 这就需要我们掌握一些应对反爬机制的技术手段 本文将从不同层面介绍如何使用 Python 进行
  • 概率论与数理统计学习笔记——第三十讲——方差定义和计算公式

    1 方差概念的引入 2 方差 标准差 均方差 的定义及计算公式 3 0 1分布的方差 4 泊松分布的方差 5 均匀分布的方差 6 指数分布的方差 7 方差的应用实例 投资方案评估
  • Kubernetes Configmap + Secret

    Secret是什么 在Kubernetes中 Secret是一种用于存储敏感信息的资源对象 它主要用于保存密码 API令牌 密钥和其他敏感数据 以供容器 Pod或集群中的其他资源使用 Secret有以下特点 安全存储 Secret对象被用于
  • Eclipse 搭建一个servlet小程序

    跳转 http www importnew com 14621 html Servlet 是一些遵从Java Servlet API的Java类 这些Java类可以响应请求 尽管Servlet可以响应任意类型的请求 但是它们使用最广泛的是响
  • C++中break与continue的用法

    根据break的用法 是在循环体内 强行结束循环的执行 也就是结束整个循环的过程 不再执行循环的条件是否成立 直接转向循环语句下面的语句 continue的作用 在循环语句中 跳出本次循环中余下尚未执行的语句 继续执行下一次循环 其包括两点
  • sqli-labs 1——20关攻略

    1 10 GET传输 Less 1联合查询 优点 查询方便 速度很快 缺点 必须要有显示位 1 判断sql语句中一共返回了多少列 order by 3 对比如下两张图的显示页面 得知有3列 2 查看显示位 union select 1 2
  • 你在用什么写用例

    这段时间用例评审项目组三个成员 有用excel的 有用xmind的 有用禅道的 而我关于用例用到xmind 后来用excel 后来用禅道一直到现在 xmind是思路分析和整理的工具 在最开始做测试的前3年可以说很依赖这款工具 后来 如果要做
  • QT的学习

    1 Test brower 文本浏览器 2 菜单栏窗体里面有预览功能 3 窗口的布局 4 信号与槽 其实就是时间处理函数 类的成员函数 2019 5 27 学习了 QFileDialog 类 就是选择文件 并且把文件名显示到line edi
  • 编写软件测试文档实验报告,黑盒测试软件测试实验报告.doc

    黑盒测试软件测试实验报告 doc 软件测试与质量课程实验报告 实验2 实验2 黑盒测试法实验 姓名院系 学号 任课教师 实验指导教师 实验地点 实验吋间 实验目的 系统地学习和理解黑盒测试的 本概念 原理 熟悉和掌握等价类划分法 边界值分析
  • bootstrap 动态添加js 页面渲染_给Shopify页面添加动态背景特效教程(傻瓜式操作模板)...

    第一种特效 多彩动态气泡向中心焦点聚合js动画 操作 复制代码如下代码 然后打开页面 切换到添加代码模式 然后复制到内容的最顶部 如下图所示
  • 【STM32技巧】使用STM32 HAL库的硬件I2C驱动RX8025T实时时钟芯片

    基础配置 使用单片机APM32F103RBT6 使用外设I2C1 PB7 SDA 使用外设I2C1 PB6 SCK STM32CUBEMX 版本5 6 配置如下 i2c c文件 File Name I2C c Description Thi
  • 目标检测中图片预处理之图片大小分析

    前言 很多做目标检测的新手 拿到数据集就迫不及待想找一个算法来跑它 内心先爽一把 包括我在内也是这样 其实样的做法不合理 我们应该先对数据集进行一些分析 找出数据集的特点 有针对性的进行检测 首先要关注的是图片大小 这个相当重要 假设测试文
  • 01Nginx源码分析之初探Nginx架构

    01Nginx源码分析之初探Nginx架构 注 接下来的源码分析我都是参考以下这位博主的 但是有些地方不对的我会修改 毕竟每个人理解不一样 并且版本为nginx stable 1 18 自娱自乐的代码人 1 初探Nginx架构 第一篇没什么
  • 深度学习之手写数字识别

    当我们开始学习编程的时候 第一件事往往是学习打印 Hello World 就好比编程入门有Hello World 机器学习入门有MNIST MNIST是一个入门级的计算机视觉数据集 它包含各种手写数字图片 它也包含每一张图片对应的标签 告诉