深度学习入坑笔记之二---手写体图像识别问题

2023-11-19

目录

前言

写在前面的话:

1本文是根据tensorflow中文社区的相关内容进行整理及适当的解释重新梳理而成,对于初学者更加友好,本文仅关注达成目标的方法,对于实现方法的具体原因本文不做深究,会在其他博客里面另行解释。

2本文数据来源为http://yann.lecun.com/exdb/mnist/数据集

3写这篇博文的日期是2019年10月1日,祝愿祖国繁荣。

MNIST 手写体图像识别是机器学习领域最基础也是最经典的案例,相当于语言学习中的‘HELLO WORLD’。他包含了各种手写体数字图片:
在这里插入图片描述
该数据集包含了0-9总共10个类别的数字图片(标签),及每个图片对应的标签(让计算机知道3是3,5是5之类)本文的上半部分主要介绍通过Softmax Regression这个简单的数学模型进行图像的预测;本文后半段则利用卷积神经网络对相同的数据集进行预测,以对比两种不同的建模思路对预测结果的影响。

通过softmax进行手写体图像建模及识别

通过softmax regresion做数学建模并识别图像的一般流程为:1,数据的输入;2,建立模型;3,训练模型;4,评估模型。

数据导入

这里数据用的是MNIST官网给出的数据,在导入数据前,我们先把tensorflow模块导入,具体代码如下:

import warnings
warnings.filterwarnings('ignore') #忽略掉运行过程中出现的警告提示
#导入相关模块
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data

#下载数据集
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
print(tf.__version__)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
1.14.0         

接下来对下载的数据集进行检查,具体代码如下:

#检查数据集
print(mnist.train.images.shape, mnist.train.labels.shape)#打印训练数据集
print("------------------------------------------------------------------------------------------")
print(mnist.test.images.shape, mnist.test.labels.shape)#打印测试数据集
print("------------------------------------------------------------------------------------------")
print(mnist.validation.images.shape, mnist.validation.labels.shape)#打印验证数据集
(55000, 784) (55000, 10)
----------------------------------------------------------------
(10000, 784) (10000, 10)
 ----------------------------------------------------------------
(5000, 784) (5000, 10)

从结果得知,下载下来的数据集被分成两部分:60000行的训练数据集(mnist.train)和10000行的测试数据集(mnist.test)。其中训练数据又分为55000个训练数据和5000个验证数据。
每一张图片均包含了28像素x28像素,用数组表示图像为长度为784的张量,如下图所示:
在这里插入图片描述

softmax建模

导入数据之后,开始进行数学建模。第一次我们利用softmax回归建立一个简单模型具体代码如下:

#该数学模型的数学结构为y=Wx+b
x = tf.placeholder('float', [None, 784])#x不是一个特定的值,而是一个占位符placeholder,关于占位符,我们会在另外的文章中详谈,同时网上也有很多详细介绍
#上面的None表示此张量的第一个维度可以是任意长度
W = tf.Variable(tf.zeros([784, 10]))#W代表权重
b = tf.Variable(tf.zeros([10]))#b代表偏置项
y = tf.nn.softmax(tf.matmul(x,W) + b)

注意,W的维度是[784,10],因为我们想要用784维的图片向量乘以它以得到一个10维的证据值向量,每一位对应不同数字类。b的形状是[10],所以我们可以直接把它加到输出上面。

训练模型

训练模型之前,我们首先要定义一个指标,用这个指标来判断最终的模型输出结果是好是坏。在机器学习中比较常用的做法是我们定义一个损失函数(loss function)/代价函数(cost function),当这个函数的结果越小时,我们认为模型的模拟程度越好。
本示例中,我们的成本函数采用‘交叉熵’(cross-entropy),关于交叉熵的具体含义,本文不做赘述。具体的代码如下所示:

#定义损失函数,判断模型好坏
y_ = tf.placeholder('float', [None, 10]) #定义一个新的占位符用于输入正确的值(即标签)
cross_entropy = -tf.reduce_sum(y_*tf.log(y))#定义交叉熵,关于交叉熵的具体含义及用法,我会在另外的文章中详细介绍
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) #选择梯度下降优化器,并将学习率设为0.01
init = tf.initialize_all_variables() #初始化变量,这句话也可以写为tf.global_variables_initializer替代
sess = tf.Session() #运行对话,开启模型
sess.run(init)
#开始训练模型,循环设为1000次
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100) #以100作为一个训练批次进行训练
sess.run(train_step, feed_dict = {x: batch_xs, y_:batch_ys}) #这里指将训练数据放进x的占位符,将标签放进y_的占位符, y是预测值,靠计算得出

我们对模型进行训练,循坏设为1000,分为10次完成。

模型评估

当我们对数据训练完毕之后,我们需要测试,我们训练好的模型是否准确,这时候就需要对模型进行一个评估。这里我们对测试数据进行评估。具体的代码如下所示:

#评估模型
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))#这行代码的目的是对比预测值y与标签y_是否匹配
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float')) #这行代码会给我们一组布尔值。为了确定正确预测项的比例,我们可以把布尔值转换成浮点数,然后取平均值。例如,[True, False, True, True] 会变成 [1,0,1,1] ,取平均值后得到 0.75.
print (sess.run(accuracy, feed_dict = {x: mnist.test.images, y_: mnist.test.labels}))#评估模型准确率
0.9147

最终结果得到的模型精度大约为91%。接下来我们会对模型进行一个简单的优化,引入卷积神经网络,然后比较卷积神经网络得到的模型精度是多少。

通过卷积网络进行手写体图像建模及识别

卷积神经网络模型需要的权重和偏置项数目数量远大于softmax模型,同时为了避免权重出现0梯度,我们应该加入适量的噪声,防止权重出现对称情况。我们使用的是ReLU神经元,因此比较好的做法是用一个较小的正数来初始化偏置项,以避免神经元节点输出恒为0的问题(dead neurons)。具体的代码实现过程如下:

初始化权重
#重新构建一个卷积神经网络,预测同样的数据集并进行比较
# 定义权重和偏置项,该做法的具体意义我们以后另讲,这里不再赘述
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev = 0.1)
    return tf.Variable(initial)
def bias_variable(shape):
    initial = tf.constant(0.1, shape = shape)
    return tf.Variable(initial)
定义卷积层及池化层

本示例中,我们设定卷积步长为1,边距填充为0,池化层采用最大池化,尺寸为2x2,具体代码实现如下:

#卷积和池化处理
def conv2d(x, W): #定义卷积层
    return tf.nn.conv2d(x, W, strides = [1, 1, 1, 1], padding = 'SAME') #步长设为1,边距填充为0

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = 'SAME')
添加层

定义完毕卷积层及池化层之后,接下来就要将卷积层和池化层逐层添加进模型之中,具体代码及解释如下:

#第一层卷积
#第一层的结构包括一个卷积层加一个最大池化层。
W_conv1 = weight_variable([5, 5, 1, 32])# 前两个维度代表patch大小,1代表通道数目,32是输出的通道数目
b_conv1 = bias_variable([32])#对应上面每一个输出的通道
x_image = tf.reshape(x, [-1,28,28,1])#x的维度应该和W对应,其中第2、3维对应图片的宽和高,最后嗲表颜色通道数
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)#把x和全职进行卷积,再加上偏置项,应用RELU激活函数防止线性化
h_pool1 = max_pool_2x2(h_conv1)#添加池化层

#第二层卷积层
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

#密集连接层
W_fc1 = weight_variable([7 * 7 * 64, 1024])#图片尺寸由28减少到了7,原因是经历了两次2x2的最大池化
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

#dropout
#这一层的目的是防止模型过拟合,过拟合的模型会影响泛化能力
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

#输出层
#卷积神经网络的最后输出层依然采取全连接的形式
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
训练及评估模型

由于模型教之前复杂,数据量较大,我们采取了ADAM优化器进行梯度下降。

#训练和评估模型
sess = tf.InteractiveSession()#这个一定要添加,否则会话无法计算
cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
sess.run(tf.initialize_all_variables())
for i in range(20000):
  batch = mnist.train.next_batch(50)
    if i%100 == 0:
     train_accuracy = accuracy.eval(feed_dict={
     x:batch[0], y_: batch[1], keep_prob: 1.0})
     print('step %d, training accuracy %g'%(i, train_accuracy))
    train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
print("test accuracy %g"%accuracy.eval(feed_dict={
x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
step 0, training accuracy 0.1
step 100, training accuracy 0.76
step 200, training accuracy 0.92
step 300, training accuracy 0.92
step 400, training accuracy 0.9
......
step 19400, training accuracy 1
step 19500, training accuracy 1
step 19600, training accuracy 1
step 19700, training accuracy 1
step 19800, training accuracy 1
step 19900, training accuracy 1

test accuracy 0.9916

从最终的测试结果得出,使用了卷积神经网络的模型在手写体图像识别问题上的识别率可以提高到99.2%

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习入坑笔记之二---手写体图像识别问题 的相关文章

  • 各操作系统下安装docker

    1 查看服务器软硬件信息 1 1 判断操作系统类型 操作系统 基于发行版 统信UOS Debian 银河麒麟 StartOS Debian openEuler CentOS 优麒麟 Ubuntu Kylin Ubuntu 中标麒麟 Kyli
  • Java算法题:两数之和

    LeetCode原题 给你一个下标从 1 开始的整数数组 numbers 该数组已按 非递减顺序排列 请你从数组中找出满足相加之和等于目标数 target 的两个数 示例 1 输入 numbers 2 7 11 15 target 9 输出
  • STM32F103VG使用RTT实现发送DMX512调光数据

    DMX512调光协议和DALI一样属于数字调光协议 一个完整的DMX512数据包格式 1break 1mab 1startcode 512个调光数据 DMX512发送是基于485串口的基础上实现的特殊的数据协议 使用RTT需要把串口打开并且
  • 大话数据结构:线性表(顺序存储结构)

    线性表 零个或多个数据元素的有限序列 直接前驱元素 直接后继元素 线性表的长度 线性表元素的个数n 线性表的抽象数据类型 ADT线性表 list Data 线性表的数据对象集合为 a1 a2 an 每个元素的类型均为Datatype 其中
  • 微软服务器的主要功能,数据库服务器主要功能

    数据库服务器主要功能 内容精选 换一换 HANA全称High performanceAnalyticAppliance是由SAP开发的基于内存的面向行 列存储的关系型数据库管理系统 其作为数据库服务器的主要功能是根据应用程序的要求存储和检索
  • jdk17下载

    官网下载 https download oracle com java 17 latest jdk 17 windows x64 bin zip
  • 也想做一个绝地求生版的汽车控制移动,进来瞧瞧?(干货满满)

    控制车子移动 效果图附上 1 首先4个车轮复制一遍为车轮2备用 2 给车轮2全部添加wheel collider 只剩下车轮碰撞器和transform组件 3 给原版4个车轮添加脚本wheel 变量共有 面板赋值 依次添加车轮2里面的车轮c
  • c#图解教程和c#高级编程电子书链接

    链接 https pan baidu com s 1y TM08JvyBh8kQ0v7uT5hg 提取码 b0cq
  • Python的多维空数组赋值

    Python里面的list tuple默认都是一维的 创建二维数组或者多维数组也是比较简单 可以这样 list1 1 2 list1 append 3 4 可以这样 list2 1 2 3 4 还可以这样 list3 1 2 list3 i
  • android界面监控,防劫持

    1 首先要对自己应用的activity建立一个白名单 2 权限
  • http协议从客户端提交数据给服务器并返回数据

    老罗视频学习 本例从客户端提交数据给服务器 服务器接收到数据之后 看是否匹配 匹配返回字符串 login is success 失败返回 login is error 一 客户端 初始化url地址 private static String
  • Git如何比较不同分支的差异

    前两天 良许在做集成的时候碰到了一件闹心事 事情是这样的 良许的一位同事不小心把一个错误的 dev 分支 merge 到了 master 分支上 导致了良许编译不通过 于是 我们需要将版本回退到 merge 之前的状态 如果是下面这个状态
  • 电子设计竞赛(三)-SPWM与PID

    1 SPWM波调制技术 逆变电路的控制方式主要是采用SPWM 正弦脉宽调制技术 IR2104控制开关管的通断来实现正弦调制 SPWM的基本思路是将一个正弦波按等宽间距分成N等份 对于每一个波形以一个等面积的脉冲来对应 使脉冲的中点与相应正弦
  • python3 hashlib库sha256、pbkdf2_hmac、blake2b基本用法

    hashlib sha256 import hashlib x hashlib sha256 x update b asd print x 1 x hexdigest x hashlib sha256 x update asd encode
  • 数据下载网站整理

    数据十分重要 如何找到理想的数据显得更重要了 这里记录自己经过网上查询到的数据 进行整理 如果侵权 请联系我删除 再次感谢网友大佬们提供的资料 1 中国气象站点数据 下载地址 https www resdc cn data aspx DAT
  • 递归算法中的时间复杂度分析

    对于一种算法的时间复杂度分析还是特别重要的 在一些非递归算法中 我们仅仅看运算次数最多的那一行代码可能执行多少次就可以 实际就是看在循环中变量的变化 但是对于递归算法中该怎么分析呢 下面介绍几种递归函数中的算法时间复杂度分析的方法 0 递推
  • 使用paramiko跨服务器传输文件/文件夹

    一些概念 SSH Secure Shell 安全外壳协议 是建立在应用层基础上的安全协议 专为远程登录和其他网络服务提供安全性的协议 SFTP SSH 文件传输协议 Secret File Transfer Protocol SFTP 安全
  • window.location.href的用法

    window location href的用法 一 前言 二 常见用例 一 前言 window location href 是一个用于获取当前页面 URL 或让浏览器跳转到新 URL 的重要方法 是 window location 对象的属
  • 【gis系列】等高线创建dem,以及高程分析,坡度分析,坡向分析

    绝对原创 首先 我们要整理一份cad的文件格式 这里我不说那么多 就是在某某地图下载后 方法很多 可以通过qgis globalmapper来操作数据 以及一些普通的地图软件直接生成 这里呢 然后进入cad 把里面的高程标注信息给删除掉 图

随机推荐

  • 机器学习资源大全

    C 计算机视觉 CCV 基于C语言 提供缓存 核心的机器视觉库 新颖的机器视觉库 OpenCV 它提供C C Python Java 以及 MATLAB接口 并支持Windows Linux Android and Mac OS操作系统 通
  • SD卡初始化以及命令详解

    SD卡是嵌入式设备中很常用的一种存储设备 体积小 容量大 通讯简单 电路简单所以受到很多设备厂商的欢迎 主要用来记录设备运行过程中的各种信息 以及程序的各种配置信息 很是方便 有这样几点是需要知道的 SD 卡是基于 flash 的存储卡 S
  • Visual Studio 创建DLL 、LIB及调用

    一 前言 在工程中 经常会根据不同的场景需求将类封装成库文件 以供他人使用 那么如何利用VS进行库 动态库 的生成呢 以下简要演示实现过程 开发环境 VS2019 二 生成DLL动态库 1 创建控制台工程 添加类库函数 2 添加函数代码 d
  • vue打包及运行白屏,Android低版本适配

    版本支持 对于Android 4 X无法打开的问题 具体表现 1 运行后低版本谷歌浏览器打开后白屏 2 打包后低版本Android系统打不开 白屏 打包前npm run build后低版本浏览器打开白屏 如果低版本打开白屏那么打包后低版本A
  • CUDA系列三:矩阵相乘

    本博文主要讲解下基于cuda的矩阵相乘 cuda特别擅长的就是矩阵乘法 而且也比较容易实现 通过矩阵乘法的实现 可以比较容易理解cuda的核心思想 网上也有很多基于cuda实现的矩阵乘法 但是感觉都不完成 要不就是有错 本文给出的代码都是经
  • C#学习记录(47)MSSQL数据库

    引言 微软数据库是针对中小型企业的关系型数据库 操作简单易上手 首先介绍下C NET的数据库 以 ActiveX 数据对象 ADO 为基础 以 XML 扩展标记语言 为格式传送和接收数据 C NET应用程序 lt gt ADO NET lt
  • 特征值和特征向量的几何和物理意义

    原文 http blog 163 com renguangqian 126 blog static 1624014002011711114526759 FUCk 相见很晚 如果大学期间遇到这样的文章 线代必须90分以上 特征值和特征向量的几
  • vsCode中live server插件的安装及使用

    live server 插件是用来干嘛的 本地开发常常需要搭建临时的服务 作用 1 模拟服务器的方式打开页面 2 代码改动后 会自动刷新页面 安装 使用 1 使用要求 要求项目文件夹 Demo 要单独出现在vscode侧边栏 以下两种都可以
  • 软件设计风格(干货)-架构师之路(九)

    一 软件架构风格概念 Architecture架构 体系结构 软件体系结构风格是 描述某一特定应用领域中 系统组织方式 的惯用模式 架构风格定义一个系统家族 即 一个架构的定义 一个词汇表和一组约束 词汇表包含 一些构建和连接类型 而一组约
  • 你工作效率低,可能是因为不会Python

    前言 你是不是感觉你的工作非常无聊 每天有大量的重复性的工作要做 比如在我的工作中 就有很多类似的动作 每天早上要看我们DevOps流水线跑出的结果 查看各个微服务中的重复代码率是多少 有没有增加 CleanCode中的各项指标怎么样 代码
  • 微信加拿大服务器,微信新功能,在加拿大也可以任意刷人民币了

    原标题 微信新功能 在加拿大也可以任意刷人民币了 2018 6 11 加币 人民币 4 877 加币 美金 0 757 近日 微信悄悄上线了一项新功能 这就是 亲属卡 什么是 亲属卡 简单来说 就 是 你消费 别人买单 这项功能对于我们身在
  • 2021-01-10

    RIP 协议 一 合理分配IP地址 二 配置IP地址 三 运行RIPV 2 例R1 四 配置缺省路由 五 RIPV2 认证 例R1 六 配置空接口路由 防环 例R1 七 全网可通
  • 成员变量与局部变量的区别有哪些

    成员变量是在类内部定义的变量 在类的任何方法中都可以直接使用 其作用域为整个类 成员变量有默认值 如果没有给定初始值 数值类型默认为0 布尔类型默认为false 对象类型默认为null 局部变量是在方法 代码块 循环等内部定义的变量 其作用
  • 【羊了个羊】Burp抓取IOS微信小程序数据包

    描述 最近 小游戏 羊了个羊 在朋友圈刷屏 网友纷纷表示 游戏开发者多少有个病要治 本文记录 如何使用Burp抓取ios微信小程序数据包 工具准备 Burp 苹果手机 wifi 设置记录 手机和电脑连接同一wifi burp设置新代理 手机
  • 人脸分割 人脸解析 源码推荐

    2021年 有预训练 resnet50 126m 测试代码 python face warping test py i 0 e rtnet50 decoder fcn n 11 d cuda 0 Command line arguments
  • html js c 代码大全,js常用汇总

    javascript 代码库JS函数修改html的元素内容 及修改属性内容 document getElementById aid innerHTML World document getElementById aid href http
  • CBAM——即插即用的注意力模块(附代码)

    论文 CBAM Convolutional Block Attention Module 代码 code 目录 前言 1 什么是CBAM 1 Channel attention module CAM 2 Spatial attention
  • hexo的美化——拓展篇

    基础知识 css样式 hexo themes next source css 是next主题的样式文件 决定主题的外观 hexo themes next source css main styl 汇总css文件夹中所有的样式 hexo th
  • 一段有意思的异步代码片段

    毫不夸张的说 下面的代码会有一半的人输出错误 上代码 async function getCount id return id let count 0 async function addCount num count await getC
  • 深度学习入坑笔记之二---手写体图像识别问题

    深度学习入坑笔记之二 手写体图像识别问题 目录 前言 通过softmax进行手写体图像建模及识别 数据导入 softmax建模 训练模型 模型评估 通过卷积网络进行手写体图像建模及识别 初始化权重 定义卷积层及池化层 添加层 训练及评估模型