Tenserflow学习(二)——MNIST数据集分类三层网络搭建+Dropout+tensorboard可视化

2023-11-18

1、上代码


import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 载入数据
"""one_hot参数把标签转化到0-1之间
"""
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
# 每个批次大小(每次放入训练图像数量)
batch_size = 100
# 批次数量
num_batch = mnist.train.num_examples // batch_size


# 参数概要
def variable_summaries(var):
    with tf.name_scope('summaries'):
        mean = tf.reduce_mean(var)
        # 均值
        tf.summary.scalar('mean', mean)
        with tf.name_scope('stddev'):
            stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
        # 标准差
        tf.summary.scalar('stddev', stddev)
        tf.summary.scalar('max', tf.reduce_max(var))
        tf.summary.scalar('min', tf.reduce_min(var))
        # 直方图
        tf.summary.histogram('histogram', var)


with tf.name_scope('input'):
    # placeholder
    x = tf.placeholder(tf.float32, [None, 784], name='x-input')
    y = tf.placeholder(tf.float32, [None, 10], name='y-input')

with tf.name_scope('keep_prob_And_learning_rate'):
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')
    lr = tf.Variable(0.001, dtype=tf.float32, name='learning_rate')
    tf.summary.scalar('learning_rate', lr)

with tf.name_scope('layer'):
    with tf.name_scope('net-one'):
        with tf.name_scope('wights-one'):
            w1 = tf.Variable(tf.truncated_normal([784, 1000], stddev=0.1), name='w1')
            variable_summaries(w1)
        with tf.name_scope('biases-one'):
            b1 = tf.Variable(tf.zeros([1000]) + 0.1, name='b1')
            variable_summaries(b1)
        with tf.name_scope('drop-one'):
            L1 = tf.nn.tanh(tf.matmul(x, w1) + b1)
            L1_drop = tf.nn.dropout(L1, keep_prob)
    with tf.name_scope('net-two'):
        with tf.name_scope('wights-two'):
            w2 = tf.Variable(tf.truncated_normal([1000, 500], stddev=0.1), name='w2')
            variable_summaries(w2)
        with tf.name_scope('biases-two'):
            b2 = tf.Variable(tf.zeros([500]) + 0.1, name='b2')
            variable_summaries(b2)
        with tf.name_scope('drop-two'):
            L2 = tf.nn.tanh(tf.matmul(L1_drop, w2) + b2)
            L2_drop = tf.nn.dropout(L2, keep_prob)
    with tf.name_scope('net-three'):
        with tf.name_scope('wights-three'):
            w3 = tf.Variable(tf.truncated_normal([500, 100], stddev=0.1), name='w3')
            variable_summaries(w3)
        with tf.name_scope('biases-three'):
            b3 = tf.Variable(tf.zeros([100]) + 0.1, name='b3')
            variable_summaries(b3)
        with tf.name_scope('drop-three'):
            L3 = tf.nn.tanh(tf.matmul(L2_drop, w3) + b3)
            L3_drop = tf.nn.dropout(L3, keep_prob)
    with tf.name_scope('net-four-output'):
        with tf.name_scope('wights-four-final'):
            w4 = tf.Variable(tf.truncated_normal([100, 10], stddev=0.1), name='w4')
            variable_summaries(w4)
        with tf.name_scope('biases-four-final'):
            b4 = tf.Variable(tf.zeros([10]) + 0.1, name='b4')
            variable_summaries(b4)
        with tf.name_scope('prediction'):
            prediction = tf.nn.softmax(tf.matmul(L3_drop, w4) + b4)  # 概率值转化: softmax()

with tf.name_scope('loss'):
    # loss = tf.reduce_mean(tf.square(y - prediction))  # 二次代价函数
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))  # 交叉熵代价函数
    tf.summary.scalar('loss', loss)
with tf.name_scope('train'):
    # train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
    # train_step = tf.train.AdadeltaOptimizer(learning_rate=0.2, rho=0.95).minimize(loss)
    train_step = tf.train.AdamOptimizer(lr).minimize(loss)

init = tf.global_variables_initializer()
with tf.name_scope("accuracy-total"):
    with tf.name_scope("correct_prediction"):
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))  # argmax()返回一维张量中最大值所在的位置
    # 计算准确率
    """cast()将correct_prediction列表变量中的值转换成float32 --> true=1.0,false=0.0
    """
    with tf.name_scope("accuracy"):
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))  # cast()相当于类型转换函数
        tf.summary.scalar('accuracy', accuracy)

# 合并summary
merged = tf.summary.merge_all()
with tf.Session() as sess:
    sess.run(init)
    writer = tf.summary.FileWriter('logs/', sess.graph)
    for epoch in range(51):
        sess.run(tf.assign(lr, 0.001 * (0.95 ** epoch)))  # tf.assign()属于赋值操作
        for batch in range(num_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            summary, _ = sess.run([merged, train_step], feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.9})
        writer.add_summary(summary, epoch)
        test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
        train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0})
        learning_rate = sess.run(lr)
        print('iter' + str(epoch) + ', testing accuracy:' + str(test_acc) + ', training accuracy:' + str(train_acc)
              + ', learning rate:' + str(learning_rate))


2、dropout原理认识

参考:https://blog.csdn.net/program_developer/article/details/80737724

3、tensorboard可视化

" writer = tf.summary.FileWriter(‘logs/’, sess.graph) "记录了logs存储的路径,在所设置的路径下可以找到这样的文件:
在这里插入图片描述
打开终端进入到该文件位置,输入tensorboard --logdir=.\命令(比如我的文件在D:\PycharmProject\StudyDemo\logs下):
在这里插入图片描述
在chrome中输入最后一行的:http:// … …
得到可视化数据和网络结构
在这里插入图片描述
在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tenserflow学习(二)——MNIST数据集分类三层网络搭建+Dropout+tensorboard可视化 的相关文章

  • 保存为 HDF5 的图像未着色

    我目前正在开发一个将文本文件和 jpg 图像转换为 HDF5 格式的程序 用HDFView 3 0打开 似乎图像仅以灰度保存 hdf h5py File Sample h5 img Image open Image jpg data np
  • 为什么从 Pandas 1.0 中删除了日期时间?

    我在 pandas 中处理大量数据分析并每天使用 pandas datetime 最近我收到警告 FutureWarning pandas datetime 类已弃用 并将在未来版本中从 pandas 中删除 改为从 datetime 模块
  • 元组有什么用?

    我现在正在学习 Python 课程 我们刚刚介绍了元组作为数据类型之一 我阅读了它的维基百科页面 但是 我无法弄清楚这种数据类型在实践中会有什么用处 我可以提供一些需要一组不可变数字的示例吗 也许是在 Python 中 这与列表有何不同 每
  • 如何用python脚本控制TP LINK路由器

    我想知道是否有一个工具可以让我连接到路由器并关闭它 然后从 python 脚本重新启动它 我知道如果我写 import os os system ssh l root 192 168 2 1 我可以通过 python 连接到我的路由器 但是
  • Python 中的哈希映射

    我想用Python实现HashMap 我想请求用户输入 根据他的输入 我从 HashMap 中检索一些信息 如果用户输入HashMap的某个键 我想检索相应的值 如何在 Python 中实现此功能 HashMap
  • 跟踪 pypi 依赖项 - 谁在使用我的包

    无论如何 是否可以通过 pip 或 PyPi 来识别哪些项目 在 Pypi 上发布 可能正在使用我的包 也在 PyPi 上发布 我想确定每个包的用户群以及可能尝试积极与他们互动 预先感谢您的任何答案 即使我想做的事情是不可能的 这实际上是不
  • 您可以格式化 pandas 整数以进行显示,例如浮点数的“pd.options.display.float_format”?

    我见过this https stackoverflow com questions 18404946 py pandas formatdataframe and this https stackoverflow com questions
  • 如何在 Python 中解析和比较 ISO 8601 持续时间? [关闭]

    Closed 这个问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 我正在寻找一个 Python v2 库 它允许我解析和比较 ISO 8601 持续时间may处于不同单
  • Python beautifulsoup 仅限 1 级文本

    我看过其他 beautifulsoup 得到相同级别类型的问题 看来我的有点不同 这是网站 我正试图拿到右边那张桌子 请注意表的第一行如何展开为该数据的详细细分 我不想要那个数据 我只想要最顶层的数据 您还可以看到其他行也可以展开 但在本例
  • 在Python中检索PostgreSQL数据库的新记录

    在数据库表中 第二列和第三列有数字 将会不断添加新行 每次 每当数据库表中添加新行时 python 都需要不断检查它们 当 sql 表中收到的新行数低于 105 时 python 应打印一条通知消息 警告 数量已降至 105 以下 另一方面
  • 如何通过 TLS 1.2 运行 django runserver

    我正在本地 Mac OS X 机器上测试 Stripe 订单 我正在实现这段代码 stripe api key settings STRIPE SECRET order stripe Order create currency usd em
  • Numpy - 根据表示一维的坐标向量的条件替换数组中的值

    我有一个data多维数组 最后一个是距离 另一方面 我有距离向量r 例如 Data np ones 20 30 100 r np linspace 10 50 100 最后 我还有一个临界距离值列表 称为r0 使得 r0 shape Dat
  • pip 列出活动 virtualenv 中的全局包

    将 pip 从 1 4 x 升级到 1 5 后pip freeze输出我的全局安装 系统 软件包的列表 而不是我的 virtualenv 中安装的软件包的列表 我尝试再次降级到 1 4 但这并不能解决我的问题 这有点类似于这个问题 http
  • Python3 在 DirectX 游戏中移动鼠标

    我正在尝试构建一个在 DirectX 游戏中执行一些操作的脚本 除了移动鼠标之外 我一切都正常 是否有任何可用的模块可以移动鼠标 适用于 Windows python 3 Thanks I used pynput https pypi or
  • Python:XML 内所有标签名称中的字符串替换(将连字符替换为下划线)

    我有一个格式不太好的 XML 标签名称内有连字符 我想用下划线替换它 以便能够与 lxml objectify 一起使用 我想替换所有标签名称 包括嵌套的子标签 示例 XML
  • 如何在 pygtk 中创建新信号

    我创建了一个 python 对象 但我想在它上面发送信号 我让它继承自 gobject GObject 但似乎没有任何方法可以在我的对象上创建新信号 您还可以在类定义中定义信号 class MyGObjectClass gobject GO
  • 在本地网络上运行 Bokeh 服务器

    我有一个简单的 Bokeh 应用程序 名为app py如下 contents of app py from bokeh client import push session from bokeh embed import server do
  • 如何计算Python中字典中最常见的前10个值

    我对 python 和一般编程都很陌生 所以请友善 我正在尝试分析包含音乐信息的 csv 文件并返回最常听的前 n 个乐队 从下面的代码中 每听一首歌曲都是一个列表中的字典条目 格式如下 album Exile on Main Street
  • cv2.VideoWriter:请求一个元组作为 Size 参数,然后拒绝它

    我正在使用 OpenCV 4 0 和 Python 3 7 创建延时视频 构造 VideoWriter 对象时 文档表示 Size 参数应该是一个元组 当我给它一个元组时 它拒绝它 当我尝试用其他东西替换它时 它不会接受它 因为它说参数不是
  • 使用随机放置的 NaN 创建示例 numpy 数组

    出于测试目的 我想创建一个M by Nnumpy 数组与c随机放置的 NaN import numpy as np M 10 N 5 c 15 A np random randn M N A mask np nan 我在创建时遇到问题mas

随机推荐