深度学习手记(七)之MNIST实现CNN模型

2023-11-08

  手写字体识别是一个很好练习CNN框架搭建的数据集。下面简单讲述一下整个模型构建的思路:
在这里插入图片描述  整个模型通过两次卷积、两次亚采样以及两次全连接层,整个结构比较简单,也易理解。其中,两次卷积层的大小都为5x5,过滤器分别为32和64个,为了不改变图片的大小,设置padding参数为“same”,步长为1,激活函数为Relu;两次亚采样层(Pool)的大小都为2x2,步长设为2,以至于图片尺寸缩小一倍。通过两次卷积两次亚采样之后,图像的维度就变为了7x7x64,再经过两次全连接层以及通过softmax激活函数之后与十种类别相匹配,预测手写字体的数字。
  分别使用TensorFlow和Keras实现代码

**

1. TensorFlow

**

import numpy as np
import tensorflow as tf
# 下载MNIST数据
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("mnist_data", one_hot=True)

# 使用占位符创建输入变量
# None 表示张量(Tensor)的第一个维度可以是任何长度
# 除以 255 是为了做 归一化(Normalization),把灰度值从 [0, 255] 变成 [0, 1] 区间
# 归一话可以让之后的优化器(optimizer)更快更好地找到误差最小值
input_x = tf.placeholder("float32", [None, 28*28]) / 255.
output_y = tf.placeholder("int32", [None, 10])
# -1 表示自动推导维度大小。让计算机根据其他维度的值
# 和总的元素大小来推导出 -1 的地方的维度应该是多少
input_x_images = tf.reshape(input_x, [-1, 28, 28, 1])

# MNIST测试数据集(选取其中3000张)
test_x = mnist.test.images[:3000]
test_y = mnist.test.labels[:3000]

# 构建卷积神经网络
# # 第一层卷积
conv1 = tf.layers.conv2d(
    inputs=input_x_images,  # 形状 [28, 28, 1]
    filters=32,             # 32 个过滤器,输出的深度(depth)是32
    kernel_size=[5, 5],     # 过滤器在二维的大小是 (5 * 5)
    strides=1,              # 步长是 1
    padding='same',         # same 表示输出的大小不变,因此需要在外围补零 2 圈
    activation=tf.nn.relu   # 激活函数是 Relu
)  # 形状 [28, 28, 32]
# # 第一层池化层(亚采样)
pool1 = tf.layers.max_pooling2d(
    inputs=conv1,      # 形状 [28, 28, 32]
    pool_size=[2, 2],  # 过滤器在二维的大小是(2 * 2)
    strides=2          # 步长是 2
)  # 形状 [14, 14, 32]
# # 第 2 层卷积
conv2 = tf.layers.conv2d(
    inputs=pool1,          # 形状 [14, 14, 32]
    filters=64,            # 64 个过滤器,输出的深度(depth)是64
    kernel_size=[5, 5],    # 过滤器在二维的大小是 (5 * 5)
    strides=1,             # 步长是 1
    padding='same',        # same 表示输出的大小不变,因此需要在外围补零 2 圈
    activation=tf.nn.relu  # 激活函数是 Relu
)  # 形状 [14, 14, 64]
# # 第 2 层池化(亚采样)
pool2 = tf.layers.max_pooling2d(
    inputs=conv2,      # 形状 [14, 14, 64]
    pool_size=[2, 2],  # 过滤器在二维的大小是(2 * 2)
    strides=2          # 步长是 2
)  # 形状 [7, 7, 64]
# # 平坦化(flat)。降维
flat = tf.reshape(pool2, [-1, 7 * 7 * 64])  # 形状 [7 * 7 * 64, ]
# # 1024 个神经元的全连接层
dense = tf.layers.dense(inputs=flat, units=1024, activation=tf.nn.relu)
# # Dropout : 丢弃 50%(rate=0.5)
dropout = tf.layers.dropout(inputs=dense, rate=0.5)
# 10 个神经元的全连接层,这里不用激活函数来做非线性化了
logits = tf.layers.dense(inputs=dropout, units=10)  # 输出。形状 [1, 1, 10]

# 卷积神经网络的优化过程
# # 计算误差(先用 Softmax 计算百分比概率,再用 Cross entropy(交叉熵)来计算百分比概率和对应的独热码之间的误差)
loss = tf.losses.softmax_cross_entropy(onehot_labels=output_y, logits=logits)
# # Adam 优化器来最小化误差,学习率 0.001
train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
# # 精度。计算 预测值 和 实际标签 的匹配程度
# # 返回 (accuracy, update_op), 会创建两个局部变量
accuracy = tf.metrics.accuracy(
    labels=tf.argmax(output_y, axis=1),
    predictions=tf.argmax(logits, axis=1))[1]
# # 创建会话
sess = tf.Session()
# # 初始化变量:全局和局部
init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
# # 变量生效
sess.run(init)

# 训练过程
# # 训练 5000 步。这个步数可以调节
for i in range(5000):
    batch = mnist.train.next_batch(50)  # 从 Train(训练)数据集里取 “下一个” 50 个样本
    train_loss, train_op_ = sess.run([loss, train_op], {input_x: batch[0], output_y: batch[1]})
    if i % 100 == 0:
        test_accuracy = sess.run(accuracy, {input_x: test_x, output_y: test_y})
        print("第 {} 步的训练损失={:.4f}, 测试精度={:.2f}".format(i, train_loss, test_accuracy))

# 测试过程
# # 测试:打印 20 个预测值 和 真实值
test_output = sess.run(logits, {input_x: test_x[:20]})
inferred_y = np.argmax(test_output, 1)
print(inferred_y, '推测的数字')  # 推测的数字
print(np.argmax(test_y[:20], 1), '真实的数字')  # 真实的数字
# # 关闭会话
sess.close()

**

2. Keras

**

import numpy as np
from keras.datasets import mnist
from keras.models import Sequential
from keras.utils import np_utils
from keras.layers import Dense, Dropout, Convolution2D, MaxPooling2D, Flatten
from keras.optimizers import Adam
# 载入数据
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# (60000, 28, 28) -> (60000, 28, 28, 1)
x_train = x_train.reshape(-1, 28, 28, 1)/255.0
x_test = x_test.reshape(-1, 28, 28, 1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train, num_classes=10)
y_test = np_utils.to_categorical(y_test, num_classes=10)
# 定义顺序模型
model = Sequential()
# 第一个卷积层
# input_shape 输入平面
# filters卷积核/滤波器个数
# kernel_size卷积窗口大小
# strides步长
# padding方式
# activation激活函数
model.add(Convolution2D(
    input_shape = (28, 28, 1),
    filters = 32,
    kernel_size = 5,
    strides = 1,
    padding = "same",
    activation = "relu"))
# 第一个池化层
model.add(MaxPooling2D(
    pool_size = 2,
    strides = 2,
    padding = 'same',
))
# 第二个卷积层
model.add(Convolution2D(64,5,strides=1,padding='same',activation = 'relu'))
# 第二个池化层
model.add(MaxPooling2D(2,2,'same'))
# 把第二个池化层的输出扁平化为1维
model.add(Flatten())
# 第一个全连接层
model.add(Dense(1024,activation = 'relu'))
# Dropout
model.add(Dropout(0.5))
# 第二个全连接层
model.add(Dense(10,activation='softmax'))
# 定义优化器
adam = Adam(lr=1e-4)
# 定义优化器,loss function,训练过程中计算准确率
model.compile(optimizer=adam,loss='categorical_crossentropy',metrics=['accuracy'])
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=10)
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
print('test loss',loss)
print('test accuracy',accuracy)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习手记(七)之MNIST实现CNN模型 的相关文章

随机推荐

  • servlet跳转报错404(写的验证码显示不出来)

    原因 没有配置web xml
  • 如何在嵌入式LINUX中增加自己的设备驱动程序

    http linux chinaunix net bbs thread 138124 1 2 html 驱动程序的使用可以按照两种方式编译 一种是静态编译进内核 另一种是编译成模块以供动态加载 由于 uClinux不支持模块动态加载 而且嵌
  • Windows下使用nmake编译C/C++的makefile

    现在大多时候在Linux上做服务器端开发 使用VC的机会少了很多 VC编程时习惯上会间个小工程去测试一些小段代码 确保正确后在移植到真正的工程上去 自觉这是个好习惯 决定继续沿用 公司开发环境不提供VC 自己也懒得动用复杂的IDE 想想还是
  • SQL server查询本条数据的下一条数据,上一条数据,及其对应的值,SQL语句示例。

    1 创建测试表 IF EXISTS SELECT FROM sys all objects WHERE object id OBJECT ID N dbo testA AND type IN U DROP TABLE dbo testA G
  • Java编程练习题:1.判断一个整数是奇数还是偶数,至少有两种方式实现,2.输入一个数,判断这个是2的指数,3.两个选择题

    目录 1 判断一个整数是奇数还是偶数 至少有两种方式实现 1 1 方法一 取模法 1 2 方法二 使用按位与 运算符 2 输入一个数 判断这个是2的指数 3 考察逻辑运算符和位运算符选择题 3 1 下列哪一项是 4是奇数或 9为正数 的否定
  • 华为OD机试真题- 组合出合法最小数【2023Q1】【JAVA、Python、C++】

    题目描述 给一个数组 数组里面都是代表非负整数的字符串 将数组里所有的数值排列组合拼接起来组成一个数字 输出拼接成的最小的数字 输入描述 一个数组 数组不为空 数组里面都是代表非负整数的字符串 可以是0开头 例如 13 045 09 56
  • 哇塞,可以用Python实现电脑自动写小说了!!!

    作家 是多么一个让人感到向往的职业 我也幻想着 有一天能够靠写小说赚稿费 来实现自己的另一份可靠的收入 可惜 理想是美好的 但现实很残酷 不管怎么写 都不能赶上其他作者 自己至今仍然是一个扑街写手 我自知我的水平是真的不能冠以作家的称号 因
  • 【简单工具】BurpSuite截获请求并生成文件

    目录 1 实验目标 2 实验环境及靶机设置 2 1 实验环境 2 2 靶机设置 3 实验过程 3 1 前期准备 3 2 BurpSuite设置与操作 3 3 查看结果 4 总结 1 实验目标 设置BurpSuite为浏览器代理 拦截浏览器的
  • Django-登录demo

    本demo的登陆逻辑 如果账号密码正确 跳转至百度页面 账号密码错误 提示登录失败 正确的 账号 admin 密码 123 1 views下添加一个login方法 2 urls中去绑定一下 3 创建一个login xml 运行一下
  • AD18导入的3D模型颜色是白色解决

    问题描述 从Solidworks导入AD18的step文件 显示为白色 解决方法 Solidworks中保存step文件时选AP214格式 不要选择AP203 建议 重新导出时建议起一个和上一次不一样的名字 不然重新导入AD依然是白色 效果
  • Unity自带IAP插件使用

    Unity Services里的Unity IAP对于IOS和GooglePlay的支付用这个插件就足够了 Unity官方文档 1 集成插件 Window Services Ctrl 0 在Services面板Link你的工程 启用In A
  • 抽象问题方法论

    文章目录 模型简化 问题分解到base 流式处理 只关心当前节点问题 从设计者角度出发 思考问题 前后逻辑串联 穷举 细节是魔鬼 基础无穷尽 更高的秩序意味着更先进的文明 设计要小而美 而不是大而全 升维 降维 基于以上逻辑 需要做熵减行为
  • uniapp小程序

    uniapp小程序 uni app之响应式单位upx和rpx upx rpx简介 upx 1 动态绑定的 style 不支持直接使用 upx 2 使用 uni upx2px Number 转换为 px 后再赋值 rpx responsive
  • 华为OD机试 - N进制减法(Java)

    题目描述 主管期望你实现一个基于字符串的N进制的减法 需要对输入的两个字符串按照给定的N进制进行减法操作 输出正负符号和表示结果的字符串 输入描述 输入有三个参数 第一个参数是整数形式的进制N值 N值范围为大于等于2 小于等于35 第二个参
  • 【视频篇】创作的基石,如何找素材?

    前言 工作学习中免不了要搜集素材 然后进行二次创作 这些素材从哪来呢 别告诉我你还在直接百度之后慢慢翻 针对如何找素材 我在打算做一个专题分享一下我的 路子 常见的素材类型比如图片 视频 字体 海报模板 PPT模板等等 想到什么写什么吧 这
  • 一个人如何做抖音矩阵

    随着抖音发展的越来越成熟 不少企业 公司都开始在抖音上发力 但由于人员不够迟迟没有开始布局抖音矩阵 今天小编就来和大家聊一聊一个人怎么做抖音矩阵 一个人做抖音矩阵其实也非常简单 只需要借助矩阵管理系统即可 很多小伙伴迟迟没有做抖音矩阵营销的
  • python 字符串截取_python字符串截取、查找、分割

    Python 截取字符串使用 变量 头下标 尾下标 就可以截取相应的字符串 其中下标是从0开始算起 可以是正数或负数 下标可以为空表示取到头或尾 例1 字符串截取 str 12345678 print str 0 1 gt gt 1 输出s
  • QT 控件重绘

    前言 转载请附上连接 本帖原创请勿照抄 QT重绘控件是指通过实现控件头文件 使用QSS或者样式表来对某个控件进行重新绘制 1 重绘QButton按钮 2 重绘QComboBox下拉框 3 其它控件重绘的办法 1 重绘QButton 重绘控件
  • 竞赛选题 基于机器视觉的二维码识别检测 - opencv 二维码 识别检测 机器视觉

    文章目录 0 简介 1 二维码检测 2 算法实现流程 3 特征提取 4 特征分类 5 后处理 6 代码实现 5 最后 0 简介 优质竞赛项目系列 今天要分享的是 基于机器学习的二维码识别检测 opencv 二维码 识别检测 机器视觉 该项目
  • 深度学习手记(七)之MNIST实现CNN模型

    手写字体识别是一个很好练习CNN框架搭建的数据集 下面简单讲述一下整个模型构建的思路 整个模型通过两次卷积 两次亚采样以及两次全连接层 整个结构比较简单 也易理解 其中 两次卷积层的大小都为5x5 过滤器分别为32和64个 为了不改变图片的