基于tensorflow2.3.0的手写数字识别案例

2023-11-06

本程序使用mnist训练数据集进行训练得出模型,再利用mnist测试数据集进行验证,得出模型的实际效果。

1、引入运行需要的环境

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

2、读取mnist数据集

(train_image, train_labels), (test_image, test_labels) = tf.keras.datasets.mnist.load_data()

如果在线下载数据集,速度非常的慢。我是提前下载号mnist数据集(mnist.npz),并放到~/.keras/datasets目录下,上面的代码便直接使用该目录下的直接读取。(train_image, train_labels)存放训练数据,(test_image, test_labels)存放测试数据。

在jupyter notebook环境下可以直接使用train_image.shape来查看数据集的形状,结果为(60000, 28, 28),即训练集共60000张图片,每张图片的长宽为28*28。输入代码plt.imshow(train_image[1])可以显示第2张图片,如下:

 

train_labels是train_image对应的标签,比如输出train_labels[1]的结果为<tf.Tensor: shape=(), dtype=int64, numpy=0>。

3、扩展图片数据的维度

目前图片数据的维度为3,(60000, 28, 28)表示共60000张图片,每张图片的长宽为28*28。由于该程序使用卷积神经网络,因此我将维度扩展为4维(对每张图片添加“厚度”维度)。如下代码:

train_image = tf.expand_dims(train_image, -1)
test_image = tf.expand_dims(test_image, -1)

输入代码train_image.shape后,显示TensorShape([60000, 28, 28, 1])。

4、类型转换、数据归一化

#将train_image中的每个元素转换为float类型,同时进行归一化
train_image = tf.cast(train_image/255, tf.float32)
test_image = tf.cast(test_image/255, tf.float32)
train_labels = tf.cast(train_labels, tf.int64)
test_labels = tf.cast(test_labels, tf.int64)

5、图片和标签进行绑定

dataset = tf.data.Dataset.from_tensor_slices((train_image, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_image, test_labels))

输入代码dataset后,显示<BatchDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int64)>。

6、乱序、分组

dataset = dataset.shuffle(10000).batch(32)
test_dataset = test_dataset.batch(32)

将dataset进行乱序处理,test_dataset不用进行乱序处理,因为test_dataset不参与训练,只进行验证。将数据集batch处理(划分批),这里32个元素为1批。经过batch处理之后,就可以每次处理1批,比如下面代码可以取下一批的数据:

features, labels = next(iter(dataset))

输入代码features.shape,输出为TensorShape([32, 28, 28, 1]),即1次取道32张图片。

7、定义模型、优化器、损失函数

#定义模型
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, [3,3], activation='relu', input_shape=(None, None, 1)),  #input_shape=(None, None, 1)表示channel为1的任意大小的图片均可作为输入
    tf.keras.layers.Conv2D(32, [3,3], activation='relu'), 
    tf.keras.layers.GlobalMaxPooling2D(),
    tf.keras.layers.Dense(10)
])
#优化器
optimizer = tf.keras.optimizers.Adam()
#损失函数
#由于标签为数字型编码,不是独热型编码,所以采用SparseCategoricalCrossentropy
loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

输入model.summary()可以查看模型的结构,结构如下图:

8、先验

虽然目前只定义了该模型,并没有对其进行训练,但是依然可以调用该模型并输出结果。原因是对模型训练的实质上是不断优化模型中的参数,使其接近理想值的过程。而对于刚定义的模型,存在初始状态的参数(虽然参数值不理想),但依然可以调用该模型。

通过代码predictions = model(features)可以调用模型,输出预测结果为TensorShape([32, 10])。现在解释一下该输出结果:(1)features为32张图片,每张图片1个预测结果,所以输出结果中有32;(2)model定义中的最后部分,全连接层输出节点数为10,所以输出结果中有10。

9、定义准确率、训练误差测量变量

train_loss = tf.keras.metrics.Mean('train_loss')  #可以上网查一下metrics.Mean的使用方法,就知道了
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('train_accuracy')  #是数字编码,而非独热编码
test_loss = tf.keras.metrics.Mean('test_loss')  #可以上网查一下metrics.Mean的使用方法,就知道了
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('test_accuracy')  #是数字编码,而非独热编码

10、定义批训练函数

我们之前对数据集划分了batch,每32个元素为1组,那么我们在训练的时候,每读入32个元素,就更新一次参数。

def train_step(model, images, labels):
    with tf.GradientTape() as t:
        pred = model(images)  #pred为预测结果
        loss_step = loss_func(labels, pred)  #得出损失值
    grads = t.gradient(loss_step, model.trainable_variables) #根于损失值,对训练参数计算梯度
    optimizer.apply_gradients(zip(grads, model.trainable_variables))  #根据梯度,更新训练参数
    train_loss(loss_step) #计算平均loss值
    train_accuracy(labels, pred) #计算预测准确度

其中,model为定义的网络模型,images为32张图片,labels为32张图片对应的32个标签。

11、定义批验证函数

为了衡量当前训练的模型在测试数据集上的预测能力,需要定义批验证函数。

def test_step(model, images, labels):
    pred = model(images)  #对images进行预测
    loss_step = loss_func(labels, pred) #计算预测值和真实值之间的损失
    test_loss(loss_step)  #计算损失值的平均值
    test_accuracy(labels, pred)  #计算预测准确度

12、定义总训练函数

在定义了批训练函数和批验证函数的基础上,可以定义总训练函数了。

def train():
    for epoch in range(10):  #对整个数据集循环10次
        for(batch, (images, labels)) in enumerate(dataset): #每次从训练数据集中取一批数据(32个)进行训练
            train_step(model, images, labels) #训练
        print('Epoch{} loss is {}, accuracy is {}'.format(epoch, train_loss.result(), train_accuracy.result()))  #输出训练数据集上的损失值和准确度
        for(batch, (images, labels)) in enumerate(test_dataset):  #每次从测试数据集中取一批数据(32个)进行测试
            test_step(model, images, labels)  #测试
        print('Epoch{} test_loss is {}, test_accuracy is {}'.format(epoch, test_loss.result(), test_accuracy.result()))  #输出测试数据集上的损失值和准确度
        
        #每当遍历完1次整个数据集之后都需要进行重置,具体上网查找metrics.Mean的用法
        train_loss.reset_states() 
        train_accuracy.reset_states()
        test_loss.reset_states()
        test_accuracy.reset_states()

13、调用训练函数进行训练

输入代码train()就可以在数据集上进行训练了,由于在train函数中包含输出训练过程中的中间值(loss、准确度)的代码,所以在训练过程中就可以掌握训练的情况,如下。

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

基于tensorflow2.3.0的手写数字识别案例 的相关文章

  • 中断 Select 以添加另一个要在 Python 中监视的套接字

    我正在 Windows XP 应用程序中使用 TCP 实现点对点 IPC 我正在使用select and socketPython 2 6 6 中的模块 我有三个 TCP 线程 一个读取线程通常会阻塞select 一个通常等待事件的写入线程
  • 使用特定的类/函数预加载 Jupyter Notebook

    我想预加载一个笔记本 其中包含我在另一个文件中定义的特定类 函数 更具体地说 我想用 python 来做到这一点 比如加载一个配置文件 包含所有相关的类 函数 目前 我正在使用 python 生成笔记本并在服务器上自动启动它们 因为不同的
  • 在 django ORM 中查询时如何将 char 转换为整数?

    最近开始使用 Django ORM 我想执行这个查询 select student id from students where student id like 97318 order by CAST student id as UNSIG
  • 如何使用 opencv.omnidir 模块对鱼眼图像进行去扭曲

    我正在尝试使用全向模块 http docs opencv org trunk db dd2 namespacecv 1 1omnidir html用于对鱼眼图像进行扭曲处理Python 我正在尝试适应这一点C 教程 http docs op
  • 用枢轴点拟合曲线 Python

    我有下面的图 我想用 2 条线来拟合它 使用 python 我设法适应上半部分 def func x a b x np array x return a x b popt pcov curve fit func up x up y 我想用另
  • 使用 Python 从文本中删除非英语单词

    我正在 python 上进行数据清理练习 我正在清理的文本包含我想删除的意大利语单词 我一直在网上搜索是否可以使用像 nltk 这样的工具包在 Python 上执行此操作 例如给出一些文本 Io andiamo to the beach w
  • 删除flask中的一对一关系

    我目前正在使用 Flask 开发一个应用程序 并且在删除一对一关系中的项目时遇到了一个大问题 我的模型中有以下结构 class User db Model tablename user user id db Column db String
  • Pandas 日期时间格式

    是否可以用零后缀表示 pd to datetime 似乎零被删除了 print pd to datetime 2000 07 26 14 21 00 00000 format Y m d H M S f 结果是 2000 07 26 14
  • 立体太阳图 matplotlib 极坐标图 python

    我正在尝试创建一个与以下类似的简单的立体太阳路径图 http wiki naturalfrequent com wiki Sun Path Diagram http wiki naturalfrequency com wiki Sun Pa
  • datetime.datetime.now() 返回旧值

    我正在通过匹配日期查找 python 中的数据存储条目 我想要的是每天选择 今天 的条目 但由于某种原因 当我将代码上传到 gae 服务器时 它只能工作一天 第二天它仍然返回相同的值 例如当我上传代码并在 07 01 2014 执行它时 它
  • 如何在不丢失注释和格式的情况下更新 YAML 文件 / Python 中的 YAML 自动重构

    我想在 Python 中更新 YAML 文件值 而不丢失 Python 中的格式和注释 例如我想改造 YAML 文件 value 456 nice value to value 6 nice value 界面类似于 y yaml load
  • Docker 中的 Python 日志记录

    我正在 Ubuntu Web 服务器上的 Docker 容器中测试运行 python 脚本 我正在尝试查找由 Python Logger 模块生成的日志文件 下面是我的Python脚本 import time import logging
  • 如何通过 TLS 1.2 运行 django runserver

    我正在本地 Mac OS X 机器上测试 Stripe 订单 我正在实现这段代码 stripe api key settings STRIPE SECRET order stripe Order create currency usd em
  • 加快网络抓取速度

    我正在使用一个非常简单的网络抓取工具抓取 23770 个网页scrapy 我对 scrapy 甚至 python 都很陌生 但设法编写了一个可以完成这项工作的蜘蛛 然而 它确实很慢 爬行 23770 个页面大约需要 28 小时 我看过scr
  • javascript 是否有等效的 __repr__ ?

    我最接近Python的东西repr这是 function User name password this name name this password password User prototype toString function r
  • 仅第一个加载的 Django 站点有效

    我最近向 stackoverflow 提交了一个问题 标题为使用mod wsgi在apache上多次请求后Django无限加载 https stackoverflow com questions 71705909 django infini
  • 如何在 pygtk 中创建新信号

    我创建了一个 python 对象 但我想在它上面发送信号 我让它继承自 gobject GObject 但似乎没有任何方法可以在我的对象上创建新信号 您还可以在类定义中定义信号 class MyGObjectClass gobject GO
  • 使用for循环时如何获取前一个元素? [复制]

    这个问题在这里已经有答案了 可能的重复 Python 循环内的上一个和下一个值 https stackoverflow com questions 1011938 python previous and next values inside
  • Django-tables2 列总计

    我正在尝试使用此总结列中的所有值文档 https github com bradleyayers django tables2 blob master docs pages column headers and footers rst 但页
  • 更改 Tk 标签小部件中单个单词的颜色

    我想更改 Tkinter 标签小部件中单个单词的字体颜色 我知道可以使用文本小部件来实现与我想要完成的类似的事情 例如使单词 YELLOW 显示为黄色 self text tag config tag yel fg clr yellow s

随机推荐

  • 通过uvm_printer的print_generic进行扩展打印

    uvm的field automation机制实现的其中一项功能就是sprint功能 该函数通过调用do print函数实现 在某些情况的 uvm的打印功能不是我们所期望的 比如多维数组的field automation机制就不支持 stru
  • k8s集群部署(rke + rancher)

    部署环境说明 cat etc redhat release CentOS Linux release 7 9 2009 Core 一 使用rke命令安装 k8s集群 1 在所有节点上安装chronyd服务 yum y install chr
  • 【LeetCode3】无重复字符的最长子串(滑动窗口)

    窗口维护的是无重复字符的最长子串 c int lengthOfLongestSubstring string s vector
  • linux下挂载和卸载cdrom

    1 查询块设备及mount位置 root slave143 lsblk NAME MAJ MIN RM SIZE RO TYPE MOUNTPOINT sr0 11 0 1 3 6G 0 rom type rom表示sr0为 cdrom设备
  • Java生成某段时间内的随机时间

    上代码 1 import java text SimpleDateFormat 2 import java util Date 3 4 public class DateUtil 5 6 7 生成随机时间 8 9 param beginDa
  • Linux部署vue项目

    一 nginx conf配置文件位置 etc nginx nginx conf 二 nginx的常用命令 1 启动 Nginx start nginx 或 systemctl start nginx 2 关闭 Nginx nginx s s
  • 【2023最全最新教程】RobotFramework的介绍与环境搭建(超详细~)

    本文使用的环境 win10系统 python3 6 一 RobotFramework介绍 1 1 框架基本介绍 1 Robot Framework 简称RF 是基于python编写的 开源的 功能自动化框架 2 RF是一款关键字驱动的测试框
  • STM32外设芯片驱动学习记录 —— (一) BH1750光照传感器驱动开发

    目录 一 芯片介绍 二 Datasheet解读 1 硬件说明 2 寄存器说明 3 通信过程 三 驱动代码编写 1 软件I2C驱动 2 BH1750芯片驱动函数 总结 一 芯片介绍 BH1750是16位数字输出型 环境光强度传感器集成电路 使
  • VanillaNet实战:使用VanillaNet实现图像分类(二)

    文章目录 训练部分 导入项目使用的库 设置随机因子 设置全局参数 图像预处理与增强 读取数据 设置Loss 设置模型 设置优化器和学习率调整算法 设置混合精度 DP多卡 EMA 定义训练和验证函数 训练函数 验证函数 调用训练和验证方法 运
  • 1.1python中print的使用方法

    1 对于初学者开始学习python 首先应该学会的就是对python中的print用法 学习一个函数 首先需要知道该函数的使用方法 使用参数以及使用后的结果 本文以pycharm解释器对python中函数print 做出以下解释 1 打开p
  • 赣榆高中2021高考成绩查询,赣榆中考成绩查询2021

    2021赣榆中考成绩查询时间方法 91中考网消息 2021年赣榆中考即将开始 在中考后 广大考生最关心的无疑就是中考成绩查询方法 赣榆中考成绩什么时候公布 根据往年经验 小编收集整理了2021赣榆中考成绩查询时间方法 具体如下 2021赣榆
  • 数字黑洞 C语言

    题目 给定任一个各位数字不完全相同的 4 位正整数 如果我们先把 4 个数字按非递增排序 再按非递减排序 然后用第 1 个数字减第 2 个数字 将得到一个新的数字 一直重复这样做 我们很快会停在有 数字黑洞 之称的 6174 这个神奇的数字
  • java: javamail 1.6.2 using jdk 19

    版权所有 2022 涂聚文有限公司 许可信息查看 描述 数据库 Ms SQL server 2019 IDE Eclipse IDE for Enterprise Java and Web Developers 2021 09 OS Win
  • Vue中打包压缩插件:compression-webpack-plugin

    1 http gzip 介绍 Encoding type gzip GNU zip 压缩格式 也是互联网上最流行的压缩格式 deflate zlib deflate 压缩格式 流行程度仅次于 gzip br 一种专门为 HTTP 优化的新压
  • Jmeter集合点

    一 集合点简介 1 我们怎么实现真正的并发 并发 指的是系统中真正操作业务的用户 在jmeter中 称为线程数 jmeter中 各个线程 用户 在进行业务操作中的顺序存在一定的随机性 2 集合点的目的 让各个线程 用户 步调一致 对系统进行
  • 小记跨域相关问题

    注解 CrossOrigin 支持跨域 跨域 不同的域名A 访问 域名B 的数据就是跨域 端口不同 也是跨域 loalhost 18081 gt localhost 18082 协议不同 也是跨域 域名不同 也是跨域 协议一直 端口一致 域
  • Verilog小心得

    一 概念 阻塞赋值 在always过程块中 当存在多条阻塞赋值语句时 在前面的赋值语句没有完成之前 后面的语句就不能被执行 阻塞赋值语句顺序执行 就像被阻塞了一样 因此被称为阻塞赋值 非阻塞赋值 lt 在always过程块中 当存在多条阻塞
  • Golang——从入门到放弃

    文章目录 一 golang 简介 1 go 语言特点 2 go 语言应用领域 3 使用 go 语言的公司有哪些 二 安装 golang 1 golang 下载安装 2 配置环境变量 三 golang 开发工具 1 安装 VSCode 2 下
  • C++ template的使用

    1 template的使用 C 的高级玩法 当然包含了模板 模板 template 是实现代码重用机制的一种工具 它可以实现类型参数化 把类型定义为参数 模板元编程 从而实现了真正的代码可重用性 模板是用来批量生成功能和形式都几乎相同的代码
  • 基于tensorflow2.3.0的手写数字识别案例

    本程序使用mnist训练数据集进行训练得出模型 再利用mnist测试数据集进行验证 得出模型的实际效果 1 引入运行需要的环境 import tensorflow as tf from tensorflow import keras fro