TensorFlow实现简单神经网络

2023-10-27

本文首发于我的个人博客QIMING.INFO,转载请带上链接及署名。

在上文(《TensorFlow快速上手》)中,我们介绍了TensorFlow中的一些基本概念,并实现了一个线性回归的例子。

本文我们趁热打铁,接着用TensorFlow实现一下神经网络吧。

TensorFlow中的神经网络可以用来实现回归算法和分类算法,本文将分别给出实现这两种算法的代码。除此之外,还将介绍一个TensorFlow中重要且常用的概念——placeholder(占位符),和一个著名的数据集:MINST数据集。

1 placeholder

在开始之前,先得说一下placeholder,中文翻译为占位符

tensor不仅以常量或变量的形式存储,TensorFlow 还提供了feed机制,该机制可以临时替代计算图中的任意操作中的tensor,可以对图中任何操作提交补丁,直接插入一个tensor。具体方法即使用tf.placeholder()为这些操作创建占位符。简单使用如下:

# 创建input1和input2这两个占位符
input1 = tf.placeholder(tf.float32)
input2 = tf.placeholder(tf.float32)
output = tf.multiply(input1,input2)

with tf.Session() as sess:
    # 通过字典的形式向input1和input2传值
    print(sess.run(output,feed_dict={input1:[7.],input2:[2.]}))

# 输出结果为:[14.]

2 神经网络实现回归算法

2.1 代码及说明

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# 使用numpy生成100个随机点作为假数据
x_data = np.linspace(-0.5,0.5,200)[:,np.newaxis]
noise = np.random.normal(0,0.02,x_data.shape)
y_data = np.square(x_data)+noise

# 定义两个placeholder
x = tf.placeholder(tf.float32,[None,1])
y = tf.placeholder(tf.float32,[None,1])

# 定义神经网络中间层
Weights_L1 = tf.Variable(tf.random_normal([1,10]))
biases_L1 = tf.Variable(tf.zeros([1,10]))
Wx_plus_b_L1 = tf.matmul(x,Weights_L1) + biases_L1
L1 = tf.nn.tanh(Wx_plus_b_L1)

# 定义神经网络输出层
Weights_L2 = tf.Variable(tf.random_normal([10,1]))
biases_L2 = tf.Variable(tf.zeros([1,1]))
Wx_plus_b_L2 = tf.matmul(L1,Weights_L2)+biases_L2
prediction = tf.nn.tanh(Wx_plus_b_L2)

# 二次代价函数
loss = tf.reduce_mean(tf.square(y-prediction))
# 定义一个梯度下降法来进行训练的优化器 学习率0.1
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as sess:
    # 变量初始化
    sess.run(tf.global_variables_initializer())
    # 训练2000次
    for step in range(2000):
        sess.run(train_step,feed_dict={x:x_data,y:y_data})
    # 获得预测值
    prediction_value = sess.run(prediction,feed_dict={x:x_data}) 
    # 画图展示结果
    plt.figure()
    plt.scatter(x_data,y_data)
    plt.plot(x_data,prediction_value,'r-',lw=5)
    plt.show()

2.2 结果

这个神经网络比较简单,使用了tanh()作为激活函数,梯度下降法为优化器,二次代价函数为损失函数。

拟合出的结果如上图红线所示,可以看出,大致是一个二次函数曲线。

3 神经网络实现分类算法

3.1 MNIST数据集简介

MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片,它也包含每一张图片对应的标签,告诉我们这个是数字几。比如,下面这四张图片的标签分别是5,0,4,1。

MNIST数据集有两部分组成:60000行的训练数据集(mnist.train)和10000行的测试数据集(mnist.test)。

每一个MNIST数据单元有两部分组成:一张包含手写数字的图片和一个对应的标签。我们把这些图片设为“xs”,把这些标签设为“ys”。训练数据集和测试数据集都包含xs和ys,比如训练数据集的图片是 mnist.train.images ,训练数据集的标签是 mnist.train.labels

每一张图片包含28像素X28像素。我们可以用一个数字数组来表示这张图片:

我们把这个数组展开成一个向量,长度是 28x28 = 784。因此,在MNIST训练数据集中,mnist.train.images 是一个形状为 [60000, 784] 的张量,第一个维度数字用来索引图片,第二个维度数字用来索引每张图片中的像素点。相对应的MNIST数据集的标签是介于0到9的数字,用来描述给定图片里表示的数字。为了用于这个教程,我们使标签数据是"one-hot vectors"。 一个one-hot向量除了某一位的数字是1以外其余各维度数字都是0。所以在此教程中,数字n将表示成一个只有在第n维度(从0开始)数字为1的10维向量。比如,标签0将表示成([1,0,0,0,0,0,0,0,0,0,0])。因此, mnist.train.labels 是一个 [60000, 10] 的数字矩阵。

3.2 代码及说明

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)

# 每个批次的大小
batch_size = 100
# 计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size

# 定义两个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])

# 创建一个简单的神经网络(无中间层)
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,W)+b)

# 交叉熵
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))
# 定义一个梯度下降法来进行训练的优化器 学习率0.2
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

# 初始化变量
init = tf.global_variables_initializer()

# 结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1)) # argmax返回一维张量中最大的值所在的位置
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

with tf.Session() as sess:
    sess.run(init)
    # 训练21轮次
    for epoch in range(21):
        for batch in range(n_batch):
            batch_xs,batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
        # 用测试数据计算模型的准确率
        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print("Iter "+str(epoch)+",Testing Accuracy "+str(acc))

3.3 结果

Iter 0,Testing Accuracy 0.8488
Iter 1,Testing Accuracy 0.8941
Iter 2,Testing Accuracy 0.9013
Iter 3,Testing Accuracy 0.9053
Iter 4,Testing Accuracy 0.9093
Iter 5,Testing Accuracy 0.91
Iter 6,Testing Accuracy 0.9119
Iter 7,Testing Accuracy 0.914
Iter 8,Testing Accuracy 0.9138
Iter 9,Testing Accuracy 0.916
Iter 10,Testing Accuracy 0.9174
Iter 11,Testing Accuracy 0.9191
Iter 12,Testing Accuracy 0.9184
Iter 13,Testing Accuracy 0.9194
Iter 14,Testing Accuracy 0.9196
Iter 15,Testing Accuracy 0.9203
Iter 16,Testing Accuracy 0.9211
Iter 17,Testing Accuracy 0.9215
Iter 18,Testing Accuracy 0.9211
Iter 19,Testing Accuracy 0.9218
Iter 20,Testing Accuracy 0.9222

本例中神经网络的输出层用了softmax()函数进行分类,损失函数用了交叉熵函数,依旧使用了梯度下降法作为优化器。

结果显示,在训练了21轮后,模型的准确率达到了92.2%,这个准确度不算高,所以还需要进行优化,优化方式下文(TensorFlow进一步优化神经网络)将介绍。

4 小结

在本文中,分别实现了神经网络的回归算法和分类算法,其中提到的有关神经网络的一些概念,如激活函数、损失函数、优化器等,先请读者自行参考相关资料,本人后续可能会补充。

5 参考资料

[1]@Bilibili.深度学习框架Tensorflow学习与应用.2018-03
[2]TensorFlow中文社区.基本用法 | TensorFlow 官方文档中文版

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TensorFlow实现简单神经网络 的相关文章

随机推荐

  • 实现流程编排设计器的心路历程

    接上回 AntV 使用AntV X6实现流程编排设计器 一文说到 流程编排设计器的实现方案是将低代码引擎和AntV X6作为画布相结合 为什么会有这样的想法 可行性 起因是业务中有用到低代码引擎的场景 它的交互形式 页面结构正好符合流程编排
  • 成为极少数-读后感

    文章目录 自序 真正的成长 都需要你孤独地翻山越岭 觉醒 自我驱动的人是不会焦虑的 上进 只有突破才叫上进 动力 我经常对自己吹牛皮 自律 围绕目标的自我约束 专注 做到勤奋的样子很容易 第二章 方向与精进 思维 建立这三个思维 增加竞争优
  • linux history命令详解

    命令行历史 当执行命令后 系统默认会在内存记录执行过的命令 当用户正常退出时 会将内存的命令历史存放对应历史文件中 默认是 bash history 登录shell时 会读取命令历史文件中记录下的命令加载到内存中 登录进shell后新执行的
  • PC微信逆向:破解聊天记录文件!

    本文转载自程序员专栏 在电子取证过程中 也会遇到提取PC版微信数据的情况 看雪 52破解和CSDN等网上的PC版微信数据库破解文章实在是太简略了 大多数只有结果没有过程 经过反复试验终于成功解密了数据库 现在把详细过程记录下来 希望大家不要
  • PC微信低版本限制登录怎么办?

    前文 最近很多小伙伴遇到了低版本的微信登录时出现 您的微信版本过低 请升级至最新版本微信后在登录微信 点击 确定 后 将跳转至最新版下载页面 或出现未能登录等字样 解决方案 安装最新版本微信 登录一次后 然后在切换低版本微信登录 一般情况下
  • Arduino 连接JDY-08蓝牙模块

    Arduino 连接JDY 08蓝牙模块 文章目录 Arduino 连接JDY 08蓝牙模块 简介 一 基本连接 二 软件连接 三 手机连接 简介 从蓝牙4 0开始包含两个蓝牙芯片模块 传统 经典蓝牙模块 Classic Bluetooth
  • PyQt5中ui文件如何转为Py文件并界面可视化

    1 在pycharm里的File里面找到setting 2 Tools工具里找到External Tools 3 选择 添加 其中Name 根据你自己想法取 这里写的是 Qt Designer Program 这里是找到你的designer
  • 应用程序签名机制

    原文链接 http www 2cto com Article 201308 237263 html Android安全机制分析 Android系统是基于Linux内核开发的 因此 Android系统不仅保留和继承了Linux操作系统的安全机
  • AttributeError: 'module' object has no attribute 的解决方法

    AttributeError module object has no attribute funSalaryGuide 这个错误相信很多django的开发人员都会遇到 一般来说都是应用没有安装完成 重新安装就可以了 这几天我遇到的情况是已
  • 【Xilinx DMA】Xilinx FPGA DMA介绍

    DMA Direct Memory Access 直接内存访问 可以在不受CPU干预的情况下 完成对内存的存取 在PS和PL两端都有DMA 其中PS端的是硬核DMA 而PL端的是软核DMA 如何选用这两个DMA呢 如果从PS端的内存DDR3
  • ###haohaohao###图神经网络之神器——PyTorch Geometric 上手 & 实战

    图神经网络 Graph Neural Networks GNN 最近被视为在图研究等领域一种强有力的方法 跟传统的在欧式空间上的卷积操作类似 GNNs通过对信息的传递 转换和聚合实现特征的提取 这篇博客主要想分享下 怎样在你的项目中简单快速
  • java 对象对象的属性_java中对象属性可以是另外一个对象或对象的参考

    7 对象的属性可以是另外一个对象或对象的参考 通过这种方法可以迅速构建一个比较大的系统 class Motor Light lights Handle left right KickStart ks Motor lights new Lig
  • Xilinx MIPI CSI license

    Xilinx MIPI CSI license 不绑定MAC地址 永久有效 支持所有Vivado版本 技术交流请加 ljy435
  • CENTOS安装curlftpfs

    首先说明 curlftpfs效率还是挺慢的 用于局域网内文件传输会出现不稳定的情况 1 Fedora可以直接yuminstall curlftpfs CentOS不行 得用DAGrepository 所以得先安装DAGrepository
  • Yearning SQL审核平台部署(Yearning-2.3.4-linux-amd64)

    参考博客 https blog csdn net weixin 45858439 article details 105277413 环境 mysql5 7 35 下载Yearning 2 3 5 linux amd64安装包https g
  • PHP7.27: connect mysql 5.7 using new mysqli_connect

  • Jenkins:报错Build step ‘Execute Windows batch command‘ marked build as failure解决办法

    Windows 下本地的 Jenkins 部署完成后 创建任务进行构建时 遇到如下报错信息 构建虽然失败了 但是命令却是执行成功了 问题就在于 Build step Execute Windows batch command marked
  • 数字媒体资产管理教材

    http vr sdu edu cn lulin course DAM
  • 产量预测文献读后整理

    文献名称 1 Data Driven End To End Production Prediction of Oil Reservoirs by EnKF Enhanced Recurrent Neural Networks 2 Produ
  • TensorFlow实现简单神经网络

    本文首发于我的个人博客QIMING INFO 转载请带上链接及署名 在上文 TensorFlow快速上手 中 我们介绍了TensorFlow中的一些基本概念 并实现了一个线性回归的例子 本文我们趁热打铁 接着用TensorFlow实现一下神