实验目的
- 了解机器学习的相关知识
- 实现基于tensorflow的手写数字识别
实验环境
- ubuntu16.04 或 windows
- python 3(默认安装版本)
- tensorflow 2.0 版本以上,或其他深度学习框架
实验内容
实现基于深度学习的 MNIST 手写数字识别
import matplotlib.pyplot as plt
import tensorflow as tf
%matplotlib inline
mnist = tf.keras.datasets.mnist
#加载数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()#(图像,标签)
x_train, x_test = x_train / 255.0, x_test / 255.0#归一化
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),#多维的输入一维化
tf.keras.layers.Dense(128, activation='relu'),#全连接层
tf.keras.layers.Dropout(0.2),#防止过拟合
tf.keras.layers.Dense(10)#输出层
])
predictions = model(x_train[:1]).numpy()
tf.nn.softmax(predictions).numpy()
#损失函数
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
loss_fn(y_train[:1], predictions).numpy()
#print(model.summary())
#编译步骤
model.compile(optimizer='adam',
loss=loss_fn,
metrics=['accuracy'])
#训练网络,epochs表示多少个训练回合
model.fit(x_train, y_train, epochs=5)
print("测试集结果如下:")
#进行测试集测试 verbose =2:为每个epoch输出一行记录
model.evaluate(x_test, y_test, verbose=2)
plt.figure(figsize=(10, 10))
#前五张测试数据图片
for i in range(5):
plt.subplot(5, 5, i + 1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(x_test[i], cmap=plt.cm.binary)
plt.show()
实验结果:
实验心得
通过Tensorflow框架来建造神经网络模型对mnist手写数据集进行了训练,对测试数据集的估计准确率基本达到95%以上,基本能够识别测试数据集的手写数字图片里的数字信息,较准确的识别出手写数字