一、编写背景
因为需要对接一个官方的编程API,本人需要自己按其要求搭建一个神经网络,以尝试调用某模块的工作。我参考了Tensorflow的参考书了解了MNIST数据集,然后我准备把MNIST数据集转换为图片格式,以适应API的要求。
同样,这个程序转化出的图片格式的MNIST数据集和标签集也非常适合初学者第一次搭建网络。
二、基础依赖
numpy,opencv,原始MNIST数据集
三、程序主体
# mnist数据集请自行下载,本程序默认数据集在./dataset的文件夹下
from tensorflow.examples.tutorials.mnist import input_data
import cv2
import numpy as np
mnist = input_data.read_data_sets("./dataset", one_hot=True)
# 原数据集的训练集有55000个样本,在此只提取10000个,按需更改
IMAGE_NUM = 10000
print('dataset import done')
def image_extract():
for i in range(0, IMAGE_NUM):
# 提取长度784的图片像素向量
img = mnist.train.images[i]
# 转换成28×28的[0,255]的整数矩阵,以方便cv2保存图片
img_re = (img.reshape(-1, 28) * 255).astype(int)
cv2.imwrite('./dataset/images/'+str(i)+'.jpg', img_re )
# print('image ' + str(i) + ' extracted.')
print('images extraction done')
def label_extract():
labels = []
for i in range(0, IMAGE_NUM):
# 提取出长度10的标签向量
lbl = list(mnist.train.labels[i])
# 我以[0-9]的整数进行保存了,实际上用原始的长度10的向量进行训练更合适,自行选择
lbl_num = lbl.index(1)
labels = labels + [lbl_num]
# print('label of image ' + str(i) + ' is ' + str(lbl_num))
# 保存为npy格式更方便读取
np.save('./dataset/label.npy', labels)
print('labels extraction done')
if __name__ == '__main__':
image_extract()
label_extract()