Tensorflow之MNIST手写数字识别:分类问题(1)

2023-10-26

一、MNIST数据集读取

one hot 独热编码
独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符
优点: 1、将离散特征的取值扩展到了欧式空间,离散特征的某个取值就对应欧式空间的某个点
    2、机器学习算法中,特征之间距离的计算或相似度的常用计算方法都是基于欧式空间的
    3、将离散型特征使用one_hot编码,会让特征之间的距离计算更加合理

import tensorflow as tf
 #MNIST数据集读取
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)

###输出结果###
#若不成功可手动到相关网站下载之后添加到文件夹中
#Extracting MNIST_data/train-images-idx3-ubyte.gz
#Extracting MNIST_data/train-labels-idx1-ubyte.gz
#Extracting MNIST_data/t10k-images-idx3-ubyte.gz
#Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

二、了解MNIST手写数字识别数据集

#了解MNIST手写数字识别数据集
print('训练集 train 数量:',mnist.train.num_examples,
      ',验证集 validation 数量:',mnist.validation.num_examples,
      ',测试集 test 数量:',mnist.test.num_examples)

###输出结果###
#训练集 train 数量: 55000 ,验证集 validation 数量: 5000 ,测试集 test 数量: 10000
print(' train images shape:',mnist.train.images.shape,
      'labels shape:',mnist.train.labels.shape)
###输出### #train images shape: (55000, 784) labels shape: (55000, 10)
#28*28=784,10分类One Hot编码

三、可视化image

#可视化image
import matplotlib.pyplot as plt

def plot_image(image):
    plt.imshow(image.reshape(28,28),cmap='binary')
    plt.show()
plot_image(mnist.train.images[1])

#进一步了解reshape()
import numpy as np
int_array = np.array([i for i in range(64)])
print(int_array)

int_array.reshape(8,8)

#行优先,逐列排列
int_array.reshape(4,16)



plt.imshow(mnist.train.images[20000].reshape(14,56),cmap='binary')
plt.show()

四、数据读取

1.采用独热编码,标签数据内容并不是直接输出值,而是输出编码

#标签数据与独热编码,
#内容并不是直接输出值,而是输出编码
mnist.train.labels[1]

输出结果:
array([ 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])

 #非one_hot编码的标签值
    mnist_no_one_hot = input_data.read_data_sets("MNIST_data/",one_hot=False)
    print(mnist_no_one_hot.train.labels[0:10])      #onr_hot = False,直接返回值

输出结果:
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
[7 3 4 6 1 8 1 0 9 8]

2.读取验证集数据

#读取验证集数据
print('validation images:',mnist.validation.images.shape,'labels:',mnist.validation.labels.shape)    

输出:
validation images: (5000, 784) labels: (5000, 10)
3.读取测试机数据

#读取测试机数据
print('tast images:',mnist.test.images.shape,'labels:',mnist.test.labels.shape)

输出结果:
tast images: (10000, 784) labels: (10000, 10)

4.一次批量读取多条数据

#一次批量读取多条数据
batch_image_xs,batch_labels_ys = mnist.train.next_batch(batch_size=10)        #next_batch()实现内部会对数据集先做shuffle
print(mnist.train.labels[0:10])
print("\n")
print(batch_labels_ys)

5.argmax()用法

argmax返回的是最大数的索引

import numpy as np
np.array(mnist.train.labels[1])
np.argmax(mnist.train.labels[1])     #argmax返回的是最大数的索引
#argmax详解
arr1 = np.array([1,3,2,5,7,0])
arr2 = np.array([[1,2,3],[3,2,1],[4,7,2],[8,3,2]])
print("arr1=",arr1)
print("arr2=",arr2)

argmax_1 = tf.argmax(arr1)
argmax_20 = tf.argmax(arr2,0)      #指定第二个参数为0,按第一维(行)的元素取值,即同列的每一行取值   以行为基准,每列取最大值的下标
argmax_21 = tf.argmax(arr2,1)       #指定第二个参数为1,则第二维(列)的元素取值,即同行的每一列取值   以列为基准,每行取最大值的下标
argmax_22 = tf.argmax(arr2,-1)     #指定第二个参数为-1,则第最后维的元素取值

with tf.Session() as sess:
    print(argmax_1.eval())
    print(argmax_20.eval())
    print(argmax_21.eval())
    print(argmax_22.eval())

输出结果:
arr1= [1 3 2 5 7 0]
arr2= [[1 2 3]
[3 2 1]
[4 7 2]
[8 3 2]]
4
[3 2 0]
[2 0 1 0]
[2 0 1 0]

五、可视化

#定义可视化函数
import matplotlib.pyplot as plt
import numpy as np
def plot_images_labels_prediction(images,labels,prediction,index,num=10):  #参数: 图形列表,标签列表,预测值列表,从第index个开始显示,缺省一次显示10幅
    fig = plt.gcf()             #获取当前图表,Get Current Figure
    fig.set_size_inches(10,12)    #1英寸等于2.45cm
    if num > 25 :      #最多显示25个子图
        num = 25
    for i in range(0,num):
        ax = plt.subplot(5,5,i+1)   #获取当前要处理的子图
        ax.imshow(np.reshape(images[index],(28,28)), cmap = 'binary')              #显示第index个图像
        title = "labels="+str(np.argmax(labels[index]))              #构建该图上要显示的title信息
        if len(prediction)>0:
            title += ",predict="+str(prediction[index])
            
        ax.set_title(title,fontsize=10)    #显示图上的title信息
        ax.set_xticks([])           #不显示坐标轴
        ax.set_yticks([])
        index += 1
    plt.show()
#可视化预测结果
# plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_result,10,10)

plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_result,10,25)

六、评估与应用

#评估模型
#完成训练后,在测试集上评估模型的准确率
accu_test = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Test Accuracy:",accu_test)
#完成训练后,在验证集上评估模型的准确率
accu_validation = sess.run(accuracy,feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
print("Test Accuracy:",accu_validation)
#完成训练后,在训练集上评估模型的准确率
accu_train = sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels})
print("Test Accuracy:",accu_train)
#应用模型
#在建立模型并进行训练后,若认为准确率可以接受,则可以使用此模型进行预测
#由于pred预测结果是one_hot编码格式,所以需要转换成0~9数字
prediction_result = sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images})

#查看预测结果中的前10项
prediction_result[0:10]

七、tf.random_normal()介绍

#tf.random_normal()介绍
norm = tf.random_normal([100])    #生成100个随机数
with tf.Session() as sess:
    norm_data = norm.eval()
print(norm_data[:10])

import matplotlib.pyplot as plt
plt.hist(norm_data)
plt.show()

输出结果:
[-1.20503342 -0.40912333 1.02314627 0.91239542 -0.44498116 1.46095467
1.71958613 -0.02297023 -0.04446657 -1.58943892]

———网易云课堂《深度学习应用开发Tensorflow实践》学习记录

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow之MNIST手写数字识别:分类问题(1) 的相关文章

  • Vue项目引入Echarts可视化图表库教程&踩坑

    Apache ECharts是一个基于 JavaScript 的开源可视化图表库 ECharts是一款基于JavaScript的数据可视化图表库 提供直观 生动 可交互 可个性化定制的数据可视化图表 ECharts最初由百度团队开源 并于2
  • AD如何设置单个元器件规则

    在我们使用AD进行PCB绘制时 有时候可以将两个元器件重叠放置 但这时由于规则设置的原因 导致DRC检检查报错 那么如何将这种报错解决呢 一 在英文模式下 按下T M 忽视掉这个检查 但这种方法一旦移动元器件又会显示报错 二 设置单个元器件
  • 负载均衡器ribbon和LoadBalancer

    负载均衡器ribbon和LoadBalancer 客户端负载均衡器 目前主流的负载方案分为以下两种 集中式负载均衡 在消费者和服务提供方中间使用独立的代理方式进行负载 有硬件的 比如 F5 也有软件的 比如 Nginx 客户端根据自己的请求
  • 测试key-value

    package com datacloudsec test collect import com alibaba fastjson JSON import com datacloudsec UEBAApplication import co
  • 【ERROR】本地计算机上的mysql服务启动停止后 Windows下mysql数据库恢复

    MySQL突然连不上的 提示 本地计算机上的mysql服务启动停止后 某些服务在未由其他服务或程序使用时将自动停止 折腾半天网上各种查 才找到方法完美解决 数据库恢复 步骤1 删除原有mysql服务 mysqld remove lt 你的m

随机推荐

  • 面向应用学习stm32(7)-TIM通用定时器-PWM输出和输入

    前导 本文的目的与 意在于面向应用的学习单片机 故不会涉及太多的原理知识 例如寄存器之类的 主要目的在于面向应用的学习单片机 学会单片机的基础用法 开发板采取野火的指南者f103 作者大二小白 写的不好的地方轻点喷 欢迎评论区交流 全部工程
  • 常见的函数式接口介绍

    Supplier
  • linux环境下python编程指南,在Linux系统中搭建Python编程环境

    Linux系统是为编程而设计的 因此在大多数Linux计算机中都默认安装了Python 1 检查Python版本 在系统中运行应用程序Terminal 如果是Ubuntu 可按Ctrl Alt T 打开终端窗口 通过执行python 注意是
  • [学习opencv]彩色图像通道分离、合成

    将彩色图像RGB三色分离出来是一个很有意义的操作 用到void split const Mat mtx vector
  • LeetCode 61. 旋转链表

    题目链接 61 旋转链表 一共用两个指针 第一个记录链表的尾节点 第一次遍历记录链表的长度 由于k可能很大所有我们对k模上n 第二次遍历记录我们用第二个指针记录要翻转的第二段链表的前一个节点 然后用尾部指针tail的next指针指向头结点
  • (超详细)基于MTCNN+FaceNet实现人脸识别及轻量级网络探索和改进(附数据集及预训练模型)

    目录 一 原理分析 二 FcaeNet源码使用 三 爬虫自制数据集 提高亚洲人脸准确率 四 加载预训练模型 五 加载自制pairs验证文档 六 轻量级网络研究 七 FaceNet扩展使用 FaceNet SVM KNN或者改进损失函数 八
  • Windows安装Linux虚拟机详细教程

    文章目录 一 VirtualBox安装 二 Vagrant安装 三 Centos安装 在线安装 离线安装 四 启动与连接 方式一 命令行登入 无需输入用户名与密码 方式二 Xshell登入 自定义ip 一 VirtualBox安装 下载 V
  • raid技术快速入门

    RAID技术 简介 raid全称为Redundant Arrays of Independent Drives 即磁盘冗余阵列 这是由多块独立磁盘 多为硬盘 组合的一个超大容量磁盘组 Raid技术意图在于把多个独立的磁盘设备组成一个容量更大
  • linux配置网络yum源

    yum是Linux环境安装软件包的一种方式 yum仓库用来存放所有的现有的 rpm包 当使用yum安装一个rpm包时 需要依赖关系 会自动在仓库中查找依赖软件并安装 yum仓库可以是本地的 也可以是HTTP FTP nfs形式的网络仓库 简
  • 运放的相位补偿 ?

    两个作用 1 改变反馈网络相移 补偿运放相位滞后 2 补偿运放输入端电容的影响 其实最终还是补偿相位 因为我们所用的运放都不是理想的 一般实际使用的运算放大器对一定频率的信号都有相应的相移作用 这样的信号反馈到输入端将使放大电路工作不稳定甚
  • 华为***技术一:L2TP概述

    原理名词解析 VPDN VPDN是承载PPP报文的 可以为企业 小型ISP 移动办公人员提供接入服务 NAS NAS网络接入服务器 Network Access Server 主要由ISP维护 连接拨号网络 是距离PPP终端地理位置最近的接
  • 小程序开发一个多少钱啊

    在今天的数字化时代 小程序已经成为一种非常流行的应用程序形式 由于它们的便捷性 易用性和多功能性 小程序吸引了越来越多的用户和企业 但是 很多人在考虑开发一个小程序时 都会遇到同一个问题 开发一个小程序需要多少钱 小程序的开发费用因人而异
  • angular 使用前端代理方式实现请求跨域,解决代理不生效问题!!

    最近玩angular 在使用代理方式进行前端跨域处理时 一直无法代理成功 查了许多资料 发现所有angular跨域教程都不完整 下边为大家奉上完整版的跨域操作 1 在项目根目录下定义proxy config json文件 2 在第1步刚刚创
  • 关于在Python的for循环中改变列表的值问题探究

    案例一 def test a 1 2 for i in a print i id a if i 5 break a a 0 2 a 1 2 输出 1 4313456192 2 4313269056 解释 在for循环语句中的变量a使用的内存
  • pwnable.tw - orw

    简单概览 与 start 不同 该程序使用动态链接 提示仅允许有限的系统调用 open read write 函数 程序运行 哪怕是输入一个字母 程序仍然会出现段错误 检查安全措施 可见栈上开了 CANARY 程序 在 IDA 中反编译可见
  • 数据结构 第三章 栈与队列

    栈 Stack 定义 限定仅在表尾进行插入和删除操作的线性表 即后进先出的线性表 Last In First Out 表尾即栈顶top 表头即栈低bottom 存储方式 顺序栈 链栈 顺序栈 一组地址连续的存储单元 一次存放自栈低到栈顶的数
  • 一种高效且节约内存的聚合数据结构的实现

    一种高效且节约内存的聚合数据结构的实现 在特定的场景中 特殊定制数据结构能够得到更加好的性能且更节约内存 聚合函数GroupArray的问题 GroupArray聚合函数是将分组内容组成一个个数组 例如下面的例子 SELECT groupA
  • qt实现侧边导航栏_UI设计干货分享:设计语言 - 侧边导航栏/分页

    原文作者 罗耀 UI 侧边导航栏 分页 步骤条的绘制方法 不管是做设计 感性 还是设计规范 理性 都是仁者见仁智者见智的 都很主观 我是想阐述出自己的想法供大家参考 文章中的数值也不是固定标准 还是希望大家根据不同的项目需求 去解决不同的实
  • Elasticsearch 基于logstash 同步MySQL8 数据

    概述 在生成业务常有将MySQL数据同步到ES的需求 如果需要很高的定制化 往往需要开发同步程序用于处理数据 但没有特殊业务需求 官方提供的logstash就很有优势了 在使用logstash我们应先了解其特性 再决定是否使用 无需开发 仅
  • Tensorflow之MNIST手写数字识别:分类问题(1)

    一 MNIST数据集读取 one hot 独热编码 独热编码是一种稀疏向量 其中 一个向量设为1 其他元素均设为0 独热编码常用于表示拥有有限个可能值的字符串或标识符 优点 1 将离散特征的取值扩展到了欧式空间 离散特征的某个取值就对应欧式