Tensorflow Cnn mnist 的一些细节

2023-05-16

Tensorflow cnn MNIST 笔记

写这个完全是记录看官网example时不懂,但后来弄懂的一些细节。当然这个可以算是对官方文档的补充,也许每个人遇到的不懂都不一样,但希望对大家有帮助。

先上代码

#!/usr/bin/env python
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as np
def weightVariable(shape):
    initial = tf.truncated_normal(shape,stddev=0.1)
    return tf.Variable(initial)
def biasVariable(shape):
    initial = tf.constant(0.1,shape=shape)
    return tf.Variable(initial)
def conv2d(x,W):
    return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME')
def maxPool2x2(x):
    return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')##

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
sess = tf.InteractiveSession()

x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
xImage = tf.reshape(x,[-1,28,28,1])

wConv1 = weightVariable([5,5,1,32])
bConv1 = biasVariable([32]);
hConv1 = tf.nn.relu(conv2d(xImage,wConv1)+bConv1)
hPool1 = maxPool2x2(hConv1)

wConv2 = weightVariable([5,5,32,64])
bConv2 = biasVariable([64])
hConv2 = tf.nn.relu(conv2d(hPool1,wConv2)+bConv2)
hPool2 = maxPool2x2(hConv2)

wFc1 = weightVariable([7*7*64,1024])
bFc1 = biasVariable([1024])
h1d = tf.reshape(hPool2,[-1,7*7*64])
hFc1 = tf.nn.relu(tf.matmul(h1d,wFc1)+bFc1)

dropProb = tf.placeholder(tf.float32)
hFc1Drop = tf.nn.dropout(hFc1,dropProb)

wFc2 = weightVariable([1024,10])
bFc2 = biasVariable([10])
y = tf.matmul(hFc1Drop,wFc2)+bFc2

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(20000):
        batch = mnist.train.next_batch(50)
        if i % 100 == 0:
            train_accuracy = accuracy.eval(feed_dict={x: batch[0], y_: batch[1], dropProb: 1.0})
            print('step %d, training accuracy %g' % (i, train_accuracy))
        train_step.run(feed_dict={x: batch[0], y_: batch[1], dropProb: 0.5})

    print('test accuracy %g' % accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, dropProb: 1.0}))

## 一些需要看api文档的细节

tf.nn.conv2d(
    input,
    filter,
    strides,
    padding,
    use_cudnn_on_gpu=None,
    data_format=None,
    name=None
)

这个函数是计算卷积的,第一个参数是输入的图片数据,在这个地方具体是一个1*784的vector,然后filter是滤波器,拿第一层卷积层说的话,fileter = wConv1 = weightVariable([5,5,1,32]). 就是说每一个卷积核是5x5的一个矩阵,滑步是1,同时这一层有32个卷积核,其实我也不懂为什么这里要提取32个特征,然后padding=’SAME’,就是说卷积后的图片大小依然是1x784,至于内部是如何做padding的,不用管。这样,经过这层卷积操作,就产生了32张图片,每张都代表原始图片不同的特征

然后就是关于bais的维度的疑问,我发现一些地方bias的维度和他相加的矩阵的维度不一致,我猜测这个相加应该是在最后一个维度上做广播,比如[M,10]+b,b是一个10维的向量,这个可能就是表示矩阵M每一个元素都加上b对应的一个分量

有了以上维度的解读,就可以知道,第一次卷积操作输出了32张不同特征的图片,大小没变。然后做了一个maxpool下采样,图片变小了。然后是第二层卷积操作,wConv2 = weightVariable([5,5,32,64])。可以看到其中第三个参数是32,因为现在有32张图片了,相当于有32个通道。然后第四个参数是64,是说明现在要提取64个特征。接下来做的是下采样,然后就是dropout,防止过拟合。其中过拟合需要有概率数据输入,所以设置了一个placeholder。

然后就是全链接层,发现全链接层其实就是简单的矩阵乘法。在FC1钱,先要将数据reshape,因为现在数据是7x7的图片,有64张,所以变成一个7x7x64的vector,然后由于全链接层有1000个神经元,所以weight的shape是[7x7x764,1000].接着又做了一次dropout,防止过拟合。然后就是第二层全连接层,然后输出10个class。

然后是tensorBoard的细节,下次补充

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow Cnn mnist 的一些细节 的相关文章

随机推荐

  • PostgreSQL教程

    一 PostgreSQL介绍 PostgreSQL是一个功能强大的 开源 的关系型数据库 底层基于C实现 PostgreSQL的开源协议和Linux内核版本的开源协议是一样的 BDS协议 xff0c 这个协议基本和MIT开源协议一样 xff
  • you-get的安装与使用

    youget简介 you get是github上python的一个开源库 https github com soimort you get xff0c 使用you get你只需要取得视频所在网页链接地址就可以很轻松的下载下来 xff0c 目
  • VS Code搭建PYQT5环境并创建Helloworld实例

    使用Python pip安装PyQt5和PyQt5 tool pip install PyQt5 pip install PyQt5 tools 在VS code中安装插件PYQT Integration 配置PYQT Integratio
  • ubuntu Xrdp远程连接Authentication is required to create a color managed device

    问题 Gnome Bug xff1a 无法点击 永不消逝的授权对话框 解决 xff1a https blog csdn net wu weijie article details 108481456
  • Linux运维入门~21.系统磁盘管理,解决u盘连接电脑无反应,解决卸载u盘正忙问题

    本节我们来了解一下linux系统的磁盘管理 识别设备常用命令有 xff1a fdisk l 查看真实存在的设备 cat proc partition 系统识别的设备 blkid 系统可使用的设备 df 系统正在挂载的设备 du 查看磁盘容量
  • 快速幂取模:求 a^b % N(C++)

    在某些情况下 xff0c 我们需要求模 N 情况下某个数的多次幂 xff0c 例如 xff1a 求多次幂结果的最后几位数 RSA算法的加解密 如果底数或者指数很大 xff0c 直接求幂再取模很容易会出现数据溢出的情况 xff0c 产生错误的
  • 新手教程:手把手教你使用Powershell批量修改文件名

    适合完全没用过 xff0c 没了解过powershell的人 1 打开Windows Powershell ISE 在任务栏搜索框中输入ISE xff0c 然后打开 xff08 我的任务栏放在右边了 xff0c 所以是这个样子 xff09
  • Mac OS 开机密码重置

    通过 Mac OS 恢复功能启动 Apple 芯片 xff1a 将 Mac 开机并继续按住电源按钮 xff0c 直至看到启动选项窗口 选择标有 选项 字样的齿轮图标 xff0c 然后点按 继续 Intel 处理器 xff1a 将 Mac 开
  • 表达式求值:从“加减”到“带括号的加减乘除”的实践过程

    本文乃Siliphen原创 xff0c 转载请注明出处 xff1a http blog csdn net stevenkylelee 为什么想做一个表达式求值的程序 最近有一个需求 xff0c 策划想设置游戏关卡的某些数值 xff0c 这个
  • Linux中source命令,在Android build 中的应用

    source命令 xff1a source命令也称为 点命令 xff0c 也就是一个点符号 xff08 xff09 source命令通常用于重新执行刚修改的初始化文件 xff0c 使之立即生效 xff0c 而不必注销并重新登录 用法 xff
  • Edge 错误代码: STATUS_ACCESS_DENIED 解决方案

    1 到C盘Edge的文件全部删掉 2 到电脑管家的软件管理重新下载Edge 或者 去官网下载 3 再次打开Edge xff0c 功能都回来了 注 xff1a 该解决方案源自于edge吧的四川男篮大佬
  • centos7.9离线安装mysql5.7

    前言 windows server2003服务器 xff0c 安装mysql提示需要net framework xff0c 费了半天劲装好了 xff0c 发现解压版redis无法启动 xff0c 换了个低版本也是无法安装 xff0c 服务器
  • vs2019-slicer编译问题记录

    3D Slicer编译过程问题记录 官网教程 xff1a https slicer readthedocs io en latest developer guide build instructions windows html 环境 CM
  • sudo npm command not found 问题解决

    这种情况通常是使用 npm 命令可以正常使用 xff0c 但使用sudo npm 命令便会报 command not found 这是什么原因呢 xff1f 输入which npm可以得到 usr local bin npm xff0c 这
  • (POJ1201)Intervals <差分约束系统-区间约束>

    Intervals Description You are given n closed integer intervals ai bi and n integers c1 cn Write a program that reads the
  • 【sv与c】sv与c交互

    网上此类文章很多 xff0c 这里暂时不放具体实现和测试结果 xff0c 后续持续更新 下面引用一些帖子 xff0c 帖子中涉及到具体做法 vcs联合编译v sv c 43 43 代码 sxlwzl的专栏 CSDN博客 1 xff0c 假设
  • stm32f103c8t6最小系统

    提示 xff1a 文章写完后 xff0c 目录可以自动生成 xff0c 如何生成可参考右边的帮助文档 文章目录 前言stm32f103c8t6构成二 xff1a 电源电路稳压模块注意 复位电路NRST 时钟电路程序下载电路JTAGSWD 启
  • udp中的connect()&bind()

    connect amp bind 的作用 udp udp connect span class hljs preprocessor include lt sys types h gt span span class hljs preproc
  • Linux Hook技术实践

    LInux Hook技术实践 什么是hook 简单的说就是别人本来是执行libA so里面的函数的 xff0c 结果现在被偷偷换成了执行你的libB so里面的代码 xff0c 是一种替换 为什么hook 恶意代码注入调用常用库函数时打lo
  • Tensorflow Cnn mnist 的一些细节

    Tensorflow cnn MNIST 笔记 写这个完全是记录看官网example时不懂 xff0c 但后来弄懂的一些细节 当然这个可以算是对官方文档的补充 xff0c 也许每个人遇到的不懂都不一样 xff0c 但希望对大家有帮助 先上代