Tensorflow(二)MNIST数据集分类

2023-10-29

1.获取数据集

有两种方式可以得到数据集,第一是直接通过mnist = input_data.read_data_sets('MNIST_data',one_hot = True)进行联网下载,但这个方法可能很慢或者连接不到服务器,所以推荐使用第二个,在MNIST 直接下载数据,然后放在当前路径下的‘MNIST_data’文件夹中,下载以后不需要解压,直接压缩包放进去就好,它会自动解压获取数据的。

这个数据集是包含了0-9这10个数字的很多图片,我们要做的就是给出一张图片上面有数字,分析出这个数字是0-9中的哪一个。

one-hot的意思就是,输出值y应该是一个这样的形式【0,1,2,3,4,5,6,7,8,9】,每一个位置对应一个数字,得到的预测值只有一位为1,其他都是0,也就是说,如果预测值是【1,0,0,0,0,0,0,0,0,0】代表这个数字为0。预测值和给定的y值都是这种形式

 

2.softmax

softmax是一个激活函数,用在多分类问题的最后一层,最后一层使用这个函数以后得到的yhat并不是最终的预测值,以数据集分类例子来说,他是由10个小数组成的向量,并且这10个数的和为1,这10个数分别代表了最后这张图片是哪个数字的几率。选择最大的一项作为1,其他项都为0,最终得到预测值。

 

3.关于数据维度的问题

上面在讲向量的时候我都说的比较模糊,没有说明到底是行向量还是列向量,在之前吴恩达老师深度学习的课程中,神经网络中各层数据和参数的维度是这样定义的:(这是简化示例图,为了方便说明w和b的维度,下面代码展示的并没有隐藏层)

但这个数据集获取到的数据定义形式,也就是X和Y正好是反着来的,如果想要按照这个模型来,把获取到的数据转置即可,但这里为了以后增加改进方便,用了数据原始的格式。

注意数据的格式会引起两个问题,第一是参数值的维度,第二是向前传播是矩阵乘法的顺序

 

 

4.代码

 

首先导入需要用到的包:

import tensorflow as tf
#mnist是tensorflow中一个实例,使用input_data来下载/引入数据
from tensorflow.examples.tutorials.mnist import input_data

 

得到并且分析数据:

#获取所有的数据,包括train-set test-set
mnist = input_data.read_data_sets('MNIST_data',one_hot = True)

执行这条命令就会得到:

Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz

它会自动解压压缩包,得到里面的数据,我们可以分析一下里面数据的格式

#训练数据的样本个数
mnist.train.num_examples  #55000

#看看训练集x和y的维度,其中images为X ,labels为Y
print(mnist.train.images.shape)   #(55000, 784)
print(mnist.train.labels.shape)   #(55000, 10)

#测试集
print(mnist.test.images.shape)   #(10000, 784)
print(mnist.test.labels.shape)   #(10000, 10)

 

当数据很多的时候,可以使用Mini-batch gradient descent来训练模型,也就是分批次的训练数据,每一次迭代将数据分别n_batch个组,每个组有batch_size个数据,用batch_size个数据对模型进行训练。

#定义批次和一共有多少批次
batch_size = 100
n_batch = mnist.train.num_examples // batch_size

 

使用tensorflow的一大好处就是,它只需要实现向前传播和计算cost值,向后传播和更新参数它会自动帮你完成。

#定义两个占位符
#占位符就是,它并没有真实的数据,但给了下面代码使用数据的机会,等在session中再通过feed-dict来把数据喂给模型
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])


#向前传播
w = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([1,10]))
prediction = tf.nn.softmax(tf.matmul(x,w)+b)

#计算代价值
cost = tf.reduce_mean(tf.square(y-prediction))

#以梯度下降的方式,目标是减少cost值训练一个神经网络 学习因子为0.2
train = tf.train.GradientDescentOptimizer(0.2).minimize(cost)

#预测计算准确度
#equal比较两个参数大小是否一样,一样返回true,不一样是false ,得到的其实是true和false的向量
#argmax求y中最大值在哪个位置,1表示对哪个维度找最大值,1表示对应10的那一维
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))

#把预测值转化为浮点 ,true变成1.0 false变成0.0 ,再求平均值,
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))


#开始训练
with tf.Session() as sess:
    #前面定义的变量,就需要先初始化变量
    sess.run(tf.global_variables_initializer())

    for i in range(100): #迭代100次
        for batch in range(n_batch):  #分批次迭代
            #获取本批次的数据
            batch_xs,batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train,feed_dict={x:batch_xs,y:batch_ys})
        
        #每迭代一次使用test集测试一下准确度
        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print('after ',i, 'the accuracy is ',acc)

 

after  0 the accuracy is  0.8299
after  10 the accuracy is  0.9061
after  20 the accuracy is  0.9133
after  31 the accuracy is  0.9188
after  40 the accuracy is  0.9197
after  50 the accuracy is  0.9212
after  60 the accuracy is  0.9234
after  70 the accuracy is  0.9241
after  80 the accuracy is  0.9243
after  90 the accuracy is  0.9257
after  99 the accuracy is  0.9258

 

5.后记

对于第一次接触tensorflow的我,这个过程还是显得很神奇的,感觉不用自己实现反向传播,工作量就少了一大半,而且它的先定义placeholder然后再feed data的机制真的灰常巧妙!!很有种定义函数中参数的感觉。tensorflow中的变量定义了以后,其实是没有值得,比如输入w得到的是 : <tf.Variable 'Variable_8:0' shape=(784, 10) dtype=float32_ref>

并不是它初始化的值,在session中run过以后,才会赋予它我们所期望的值。

这里没有设置隐藏层,隐藏层的变化也是针对向前传播的,其他的其实都和这个一样。

 

 

关于代价函数:

在这里用到的代价函数是二次代价函数,也就是用真实值-预测值取平方,使用交叉熵的代价函数其实更好,也可以提升准确率,

交叉熵的代价函数就是 y*log(y-hat),真实值乘以预测值求log,因为真实值其实是one-hot的形式,所以除了一项为1,其他项都是0,那么这个代价函数也就是 y为1的位置,对应y-hat相同位置的的概率值求对数。

上面两个代价函数都要取负数,而且说的只是当个样本,对于多个样本的情况需要把所有代价相加然后取均值

tf中也有内置函数可以直接实现交叉熵的计算:

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels = y,logits = prediction))

 

 

 

使用drop_out:

#drop_our初始化参数 为1.0代表所有的神经元都是工作的
#stddev表示数据的方差为0.1
w1 = tf.Variable(tf.truncated_normal([784,2000],stddev = 0.1))
b1 = tf.Variable(tf.zeros([1,2000])+0.1)
A1 = tf.nn.tanh(tf.matmul(x,w1)+b1)
A1_drop = tf.nn.dropout(A1,keep_prob)


w2 = tf.Variable(tf.truncated_normal([2000,20],stddev = 0.1))
b2 = tf.Variable(tf.zeros([1,20])+0.1)
A2 = tf.nn.tanh(tf.matmul(A1_drop,w2)+b2)
A2_drop = tf.nn.dropout(A2,keep_prob)

w3 = tf.Variable(tf.truncated_normal([20,10],stddev = 0.1))
b3 = tf.Variable(tf.zeros([1,10])+0.1)
A3 = tf.nn.tanh(tf.matmul(A2_drop,w3)+b3)
A3_drop = tf.nn.dropout(A3,keep_prob)

w4 = tf.Variable(tf.truncated_normal([10,10],stddev = 0.1))
b4 = tf.Variable(tf.zeros([1,10])+0.1)
prediction = tf.nn.softmax(tf.matmul(A3_drop,w4)+b4)

 

在session中训练迭代的时候添加keep_drop参数:

sess.run(train,feed_dict={x:batch_xs,y:batch_ys,keep_prob:0.8})

 

预测训练集和测试集的时候要关闭drop out:

 train_acc = sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels,keep_prob:1.0})
 test_acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0})

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow(二)MNIST数据集分类 的相关文章

随机推荐

  • 线性代数的本质(干货!)

    原文链接 https www cnblogs com TenosDoIt p 3214096 html 从大学开始接触矩阵论和线性代数 记了很多公式 但是总感觉徘徊在线性代数的门外没有进去 感觉并没有接触到它的核心概念 不巧看到了这篇博客
  • 7.Unity中c#代码学习(物理系统刚体+碰撞检测(爆炸效果实现))

    刚体 通过添加组件Physics Rigidbody 实现对物体插入物理引擎 刚体 碰撞体 查看碰撞体范围 可以编辑碰撞的范围 碰撞体 在文件中创建物理材质 右键 create Pythsics material friction摩擦力 有
  • 建站平台(WebPlus)申请建站流程图及相关使用文献

    WebPlus系统是学校信息网络中心提供的用于建设部门网站的管理平台 可实现快速建站和校内信息资源共享平台 每个独立部门原则上只能申请一个WebPlus建站空间 平台使用方法请访问 http service webplus net cn 上
  • (转)CASE WHEN 用法

    Case具有两种格式 简单Case函数和Case搜索函数 简单Case函数 CASE sex WHEN 1 THEN 男 WHEN 2 THEN 女 ELSE 其他 END Case搜索函数 CASE WHEN sex 1 THEN 男 W
  • vue脚手架搭建、介绍和初始页面的构造(图文详细)

    文章目录 什么是vue脚手架 前置环境的安装 配置node js 安装脚手架vue cli 创建项目 项目配置 项目结构 修改初始页面 样式的less语法 什么是vue脚手架 Vue脚手架 Vue CLI 是一个官方提供的命令行工具 用于快
  • 【文件上传】绕过总结

    一般绕过会分为黑名单绕过 白名单绕过 特殊类型绕过 以下为文件上传后缀绕过 黑名单绕过 1 大小写绕过 eg a JSP a Jsp a jsP a jSP等等 2 空格绕过 一般保存文件名前后带空格 保存时都会被忽略掉 而php在传输中
  • linux_fasync的总结

    fasync的总结 我们知道 驱动程序运行在内核空间中 应用程序运行在用户空间中 两者是不能直接通信的 但在实际应用中 在设备已经准备好的时 候 我们希望通知用户程序设备已经ok 用户程序可以读取了 这样应用程序就不需要一直查询该设备 的状
  • window系统配置PCL的简化方法(不需要复制一百多个依赖项目名称,直接导入配置表)

    1 下载文件 百度网盘 链接 https pan baidu com s 1WQQ8kaDilaagjoK5IrYZzA 提取码 1111 注意 直接解压在E盘 不解压在E盘也可以 后续替换环境变量和属性表文件内的地址就行 props文件
  • 【独家发布】行业深度报告:《风口上的半导体

    作为关乎国民经济和国家安全的战略型行业 半导体行业在我国占据重要地位 尤其在美国对我国半导体核心产品和零部件实行技术封锁的大背景下 国产芯片亟需实现独立自主并获得长足发展 一场全产业链国产化替代风潮正愈演愈烈 与此同时 半导体行业也收获了诸
  • 使用Python,OpenCV的Meanshift 和 Camshift 算法来查找和跟踪视频中的对象

    使用Python OpenCV的Meanshift 和 Camshift 算法来查找和跟踪视频中的对象 1 效果图 2 源码 2 1 MeanShift 2 2 Camshift Continuously Adaptive Meanshif
  • vs2019登录提示“我们无法刷新此账户的凭证”

    打开代理服务器设置 查看自动设置代理与手动设置代理的开关有没有被自动打开 如果有的话把它关掉 就能正常登录了
  • D - 整数变换问题

    整数变换问题 题意 问我们最少经过多少次变换可以将n转化为m 题解 这个题我们很容易想到就是用dfs 但是数据范围也很明显不能用直接的暴力 所以我们需要剪枝 我们假设用最原始的暴力 就是每次循环两种情况一直到最后 这样的暴力很机械 很盲目
  • 华为OD机试真题-选修课-2023年OD统一考试(B卷)

    题目描述 现有两门选修课 每门选修课都有一部分学生选修 每个学生都有选修课的成绩 需要你找出同时选修了两门选修课的学生 先按照班级进行划分 班级编号小的先输出 每个班级按照两门选修课成绩和的降序排序 成绩相同时按照学生的学号升序排序 输入描
  • 百度地图,如何成为智能化位置服务平台

    深几度 产业数字化 撰稿 吴俊宇 编辑 吴俊宇 审阅 梁欣婷 摘要 对行业而言 百度地图在当下的角色转变具备代表性意义 这是产业数字化浪潮下的一次成功转型 在过去移动生态下诞生的产品 在今天都值得深入挖掘其中的数据价值 这些价值可以延展至国
  • 安卓期末大作业-图书馆借书系统、图书借阅app(附下载链接)

    安卓期末大作业 图书馆借书系统 借书APP 可以注册登录 保存数据记录 含源码和导出app 运行截图 安卓期末大作业 图书借阅APP 老师给了95分 可以注册登录 借阅书籍 还书 含数据库存储借书记录 导入AndroidStudio即可使用
  • 信标链:以太坊2.0的新起点

    原创 市后诸葛 虽然以太坊2 0依旧用 以太坊 命名 但以太坊1 0和以太坊2 0其实是完全不同的两种架构 以太坊1 0和2 0的差别 远不是POW和POS的区别 在以太坊2 0里面 基础链就是 信标链 在真正的以太坊2 0里面 是只有po
  • @With,@Accessors(chanins=true),@ExtensionMethod——Lombok常用注解

    目录 一 With 很少用 二 Accessors 非常好用 一 fluent 布尔型 二 chain 布尔型 三 ExtensionMethod 实验阶段 一 With 很少用 这个注解可以用在类上也可以用在单个的成员变量上 使lombo
  • C++中的指针概念梳理

    在C 中指针通常难以理解 即使是有经验的程序员也常常因为调试指针引发的错误而备受折磨 笔者在学习C 时常常被指针弄得晕头转向 于是决定对指针的概念做一次梳理 希望本文能够对C 入门者有些许作用 1 指针的概念 指针 pointer 是 指向
  • Electron 实现切换暗_亮模式与主题

    文章末尾附上仓库地址 清单 模板基于 electron vite vue vue3 ts vite 组件库 element plus hooks库 vueuse useElementPlusTheme 初始化工程 使用 electron v
  • Tensorflow(二)MNIST数据集分类

    1 获取数据集 有两种方式可以得到数据集 第一是直接通过mnist input data read data sets MNIST data one hot True 进行联网下载 但这个方法可能很慢或者连接不到服务器 所以推荐使用第二个