在TensorFlow中使用自定义数据集训练自己的模型

2023-10-27

写在前面的话

    今年电赛终于结束了,身边不少小伙伴都选择了送药小车的题目,刚开始可能都觉得简单吧,循迹小车+数字识别就可以搞定。刚开始很多朋友都考虑使用OpenMv作为数字识别的平台,我手上除了大家津津乐道的OpenMv之外还有国产的K210。老实说就我个人而言我更愿意去使用K210。首先OpenMv作为一个开源项目已经很好的被K210支持,这意味着使用星瞳OpenMv能做的大部分操作K210都可以完成,甚至更好。其次,星瞳OpenMv主要基于ST公司的STM32系列处理器,目前最好的应该是用到了H7系列的片子。尽管它也能运行神经网络,但有一说一它的性能及其有限,一旦运行神经网络之后帧率降得挺严重的,某组小伙伴3 ~ 4帧的帧率。而国产芯片K210不仅价格较之更便宜,其算力更是高达1TOPS,同时硬件KPU支持神经网络常见层,性能强悍!比较有意思的是今年不少小伙伴都打算使用OpenMv作为数字识别的方案,结果纷纷翻车。倒也不是说OpenMv不能做,估计大家不是很熟悉没反应过来吧。星瞳OpenMv神经网络的在线训练平台虽然免费但是对训练时间还是有限制的,手动修改网络层数深一些之后会导致训练线程超时终止,不如本地TensorFlow训练来得实在。追根究底在OpenMv上跑神经网络基本都是tflite模型,TensorFlow本地训练好之后将模型转换好想来在OpenMv上跑也是没啥问题的。由于个人对K210的偏爱电赛期间我的OpenMv基本就放在旁边吃灰了,也看到过有几个小组和我一样使用K210做为数字识别的方案,但大家更多使用的是第三方提供的训练工具或者模型,训练后最终的效果大都不尽人意。影响卷积神经网络结果的原因有很多,模型的结构、优化器函数的选择、数据集的数量和质量等都会影响最终模型的效果。
    本文仅介绍如何在TensorFlow上使用自己的数据集训练自己的模型,关于CNN模型的构建、调优以及转换为K210支持的kmodel格式等问题均不在本文讨论范围内,有兴趣的小伙伴可自行查阅相关文档~~

一、自定义数据集的目录结构(以今年电赛数字识别为例)

在这里插入图片描述
其中 Training 是训练集, Validation 是验证集

各类中的部分样本:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

上面的数据集比较混乱,原因是刚开始我打算直接使用mnist数据集进行训练,结果识别效果并不是很好。重新制作数据集训练的时候直接使用的是前面训练的模型,虽然题目要求识别的数字只有 1~809 仍被保留了下来。

具体步骤:
    1.新建一个用于存放数据集的目录 dataset
    2.在 dataset 目录中创建 Training 目录用于存放训练集图片, Validation 目录则存放验证集图片(如果需要验证集的话)
    3.Training 目录与 Validation 目录中每个文件夹的名字就是一个标签 (Training目录与Validation目录中文件夹的数量及名称应保持一致) 。如上图中Training目录下的文件夹 0~9,分别对应mnist数据集中的 label 0~9
    4.dataset/Training/*目录下存放的就是训练集中对应该标签的图像数据(如上图中 各类中的部分样本 所示)

至此,数据集的目录已经构建完成。下一步则需要在TensorFlow中读入数据集的目录结构并解析、转换为tensor的数据类型以满足TensorFlow后面自动推理的要求


二、在TensorFlow中读取数据集

一般情况下,若使用TensorFlow自带的数据集可通过如下方式加载:

import tensorflow as tf

(train_data,train_label),(test_data,test_label) = tf.keras.datasets.mnist.load_data()

对于自定义数据集而言,我们有多种导入的方式,下文中将介绍如下两种:
    2.1 纯手工打造
    2.2 利用TensorFlow keras ImageDataGenerator

2.1 纯手工打造

先介绍第一种纯手工打造的方式。这种方式可以让你对整个数据集的读取、处理流程更清晰,灵活性更好。我们则需要手动去实现中间的过程,首先导入相关的包:

import tensorflow as tf
import pathlib

数据集的目录结构参照上文所述制作即可。自定义数据集的载入流程可以分成如下几个步骤:
    1.获取所有图片的路径
    2.获取标签并转换为数字
    3.读取图片并进行相应的预处理
    4.打包图片与标签

下面将以在TensorFlow中导入训练集 (dataset/Training) 为例进行说明,验证集 Validation 和测试集导入方式同理,下文将不再进行说明。

2.1.1 获取所有图片的路径
# 指定训练集数据的路径
my_dataset_path = 'dataset/Training'
# 指定图像要调整的大小,图像大小应与模型输入层保持一致
my_image_size = (32,32)
# 指定图像维度    1,单通道(例如灰度图);3,三通道(彩图)
my_input_shape = my_image_size + (3,)
# 指定batch
my_batch = 32
# shuffle buffer size
my_shuffle_buffer_size = 1000

AUTOTUNE = tf.data.experimental.AUTOTUNE

# 获取所有图像文件的路径
dataset_path = pathlib.Path(my_dataset_path)
all_images_paths = [str(path) for path in list(dataset_path.glob('*/*'))]
print('所有文件的路径:', all_images_paths)
print('文件总数:', len(all_images_paths))

输出结果如下:

所有文件的路径: ['dataset\\Training\\0\\1.jpg', 'dataset\\Training\\0\\21.jpg', 'dataset\\Training\\1\\00000.jpg', 'dataset\\Training\\1\\00001.jpg', 'dataset\\Training\\2\\00000.jpg', 'dataset\\Training\\2\\00001.jpg', 'dataset\\Training\\3\\00027.jpg', 'dataset\\Training\\3\\00028.jpg', 'dataset\\Training\\4\\00000.jpg', 'dataset\\Training\\4\\00001.jpg', 'dataset\\Training\\5\\00161.jpg', 'dataset\\Training\\5\\00162.jpg', 'dataset\\Training\\6\\00001.jpg', 'dataset\\Training\\6\\00002.jpg', 'dataset\\Training\\7\\00163.jpg', 'dataset\\Training\\7\\00164.jpg', 'dataset\\Training\\8\\00000.jpg', 'dataset\\Training\\8\\00001.jpg', 'dataset\\Training\\9\\19.jpg', 'dataset\\Training\\9\\4.jpg']

此时,all_images_paths 列表中存储了训练数据集中所有图片的路径,在后面我们需要通过这些路径来将图片读入到内存中

2.1.2 获取标签并转换为数字

在目标检测的任务中通常会有多个标签 (label) ,为了方便人们区分,这些标签通常会使用字符串来进行描述。但是在TensorFlow进行推理的过程中是无法是用文本类型的标签的,我们需要将其转换成数字。
首先,我们需要通过解析图片路径 (all_images_paths) 的上层目录来获取标签:

# 获取标签名称
label_name = [i.name for i in dataset_path.iterdir() if i.is_dir()]
print('标签名称:', label_name)

输出结果如下:

标签名称: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

可以看到这些标签正好就是我们目录的名字,这证明我们程序的解析是正确的。接着,我们需要为这些标签分配一个唯一的数字,用这个数字来代替标签的名字:

# 因为训练时参数必须为数字,因此为标签分配数字索引
label_index = dict((name,index)for index,name in enumerate(label_name))
print('为标签分配数字索引:', label_index)

输出结果如下:

为标签分配数字索引: {'9': 9, '1': 1, '6': 6, '5': 5, '2': 2, '0': 0, '3': 3, '4': 4, '7': 7, '8': 8}

然后我们将图片与标签的数字索引进行配对,务必保证 all_images_paths 列表中图像数据的标签必须与数字索引一致:

# 将图片与标签的数字索引进行配对(number encodeing)
number_encodeing = [label_index[i.split('\\')[2]]for i in all_images_paths]
print('number_encodeing:', number_encodeing, type(number_encodeing))

输出结果如下:

number_encodeing: [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9] <class 'list'>

为了方便截图我删除了大量的样本,此处我们数据集dataset/Training中每一类都有两个图片样本,因此读取完成后的结果与上面的输出结果是一致的
在这里插入图片描述

为方便训练,多数情况下你还需要对标签做One Hot变换:

label_one_hot = tf.keras.utils.to_categorical(number_encodeing)
print(label_one_hot)

输出结果如下:

[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]
2.1.3 读取图片并进行相应的预处理

在上面的步骤中我们完成了对数据集标签的处理,现在我们需要从磁盘中载入图像并在正式开始训练之前对图像进行相应的预处理:

def process(path, label):
    # 读入图片文件
    image = tf.io.read_file(path)
    # 将输入的图片解码为gray或者rgb
    image = tf.image.decode_jpeg(image, channels=my_input_shape[2])
    # 调整图片尺寸以满足网络输入层的要求
    image = tf.image.resize(image, my_image_size)
    # 归一化
    image /= 255.
    return image,label

process方法接收一个路径 (还有一个标签,为了方便后面打包),在该方法的内部使用TensorFlow的一些方法完成对图像的读取,同时调整图像的尺寸和维度以便满足模型的需要。最后对调整好的图像数据进行归一化处理,一般情况下我们将范围映射到 0~1 的区间,当然你也可以映射到 -0.5~0.5 等区间,根据自己的需要去设计即可。一个好的区间可以让模型更好的收敛同时精度也得以提升。
使用matplotlib.pyplot查看图像:

import matplotlib.pyplot as plt

img = process('dataset\\Training\\1\\00001.jpg', 1)
# 或
img = process(all_images_paths[2], 1)

print(img[0].shape)
plt.imshow(img[0])
plt.show()

输出结果:

(32, 32, 3)

在这里插入图片描述

从输出的结果中可以看到图像的尺寸已经调整到与我们设置的 my_imput_shape=(32,32,3) 一致

2.1.4 打包图片与标签
# 将数据与标签拼接到一起
label_one_hot = tf.cast(label_one_hot, tf.int32)
path_ds = tf.data.Dataset.from_tensor_slices((all_images_paths, label_one_hot))
image_label_ds = path_ds.map(process, num_parallel_calls=AUTOTUNE)
print('image_label_ds:', image_label_ds)

输出结果:

image_label_ds: <ZipDataset shapes: ((32, 32, 3), (10,)), types: (tf.float32, tf.int32)>

tf.data.Dataset.from_tensor_slices 的作用是对数据集进行切片。上文中我们将所有图片的路径 all_images_paths 以及 label_one_hot 传入,经 map() 后得到一个二元组,其中索引0存放的是图像数据,索引1存放的则是其label对应的One Hot编码:

# 以第3个数据为例,输出其结果
res = [i for i in image_label_ds.take(3)][-1]
print(res[0])
print(res[1])

# 显示图片
plt.imshow(res[0])
plt.show()

输出结果:

tf.Tensor(
[[[0.48088235 0.54362744 0.49803922]
  [0.5034314  0.5495098  0.5122549 ]
  [0.5264706  0.56960785 0.5382353 ]
  ...
  [0.54068625 0.60343134 0.5602941 ]
  [0.5112745  0.59656864 0.5254902 ]
  [0.4970588  0.58137256 0.53137255]]

 [[0.46813726 0.53088236 0.49019608]
  [0.50980395 0.5529412  0.52156866]
  [0.5137255  0.5568628  0.5254902 ]
  ...
  [0.5352941  0.5980392  0.55490196]
  [0.5132353  0.57598037 0.5328431 ]
  [0.5156863  0.5745098  0.5470588 ]]

 [[0.49509802 0.5421569  0.5029412 ]
  [0.5107843  0.55784315 0.51862746]
  [0.5323529  0.5715686  0.54019606]
  ...
  [0.5387255  0.5857843  0.54656863]
  [0.50980395 0.5715686  0.5205882 ]
  [0.50735295 0.57009804 0.53088236]]

 ...

 [[0.38333333 0.44607842 0.40686274]
  [0.42009804 0.4632353  0.43186274]
  [0.43333334 0.4764706  0.46078432]
  ...
  [0.422549   0.50490195 0.44607842]
  [0.4240196  0.5181373  0.44558823]
  [0.41813725 0.5122549  0.45735294]]

 [[0.3882353  0.4392157  0.4       ]
  [0.40882352 0.4598039  0.42058823]
  [0.42941177 0.48039216 0.44509804]
  ...
  [0.44166666 0.5122549  0.45735294]
  [0.42647058 0.5088235  0.44215685]
  [0.39950982 0.4779412  0.43088236]]

 [[0.36862746 0.44509804 0.39803922]
  [0.3882353  0.46470588 0.41764706]
  [0.40441176 0.47990197 0.4377451 ]
  ...
  [0.43186274 0.502451   0.44754902]
  [0.41127452 0.49362746 0.4269608 ]
  [0.40686274 0.4852941  0.43823528]]], shape=(32, 32, 3), dtype=float32)
tf.Tensor([0 1 0 0 0 0 0 0 0 0], shape=(10,), dtype=int32)

在这里插入图片描述

为防止过拟合增强模型的泛化能力,我们还需要将数据集中的顺序打乱,先来看看打乱之前图片的顺序:

def display_more_image(image_label_ds, s_pos, e_pos, max_r, max_c):
    if e_pos >= s_pos:
        index = 1
        plt.figure()
        for n,image in enumerate(image_label_ds.take(e_pos)):
            if n >= s_pos-1:
                img = image[0]
                label = image[1]
                plt.subplot(max_r, max_c, index)
                index += 1
                plt.imshow(img)
                plt.xlabel(str(list(label.numpy()).index(max(label.numpy()))))
        plt.show()

display_more_image(image_label_ds, 1, 6, 2, 3)

输出结果:
在这里插入图片描述

现在打乱顺序

# 打乱dataset中的元素并设置batch
image_label_ds = image_label_ds.shuffle(my_shuffle_buffer_size).batch(my_batch)

显示打乱顺序后的图片:

# 显示第一个bath的前六张图片
index = 1
plt.figure()
for i in image_label_ds.take(1):
    for j in range(6):
        plt.subplot(2, 3, index)
        index += 1
        # print(i[1][j])
        plt.imshow(i[0][j])
        plt.xlabel(str(list(i[1][j].numpy()).index(max(i[1][j].numpy()))))

# 务必等上面的for执行完成后再调用显示
plt.show()

执行结果:
在这里插入图片描述

可见使用shuffle打乱顺序后的输出与打乱之前的图片顺序是不一致的!

至此,TensorFlow加载自定义数据集的操作就完成了,之后就是构建网络模型以及训练。


2.2 利用TensorFlow keras ImageDataGenerator

这种导入数据的方式较之上面那种更简单、更优雅,主要使用keras ImageDataGenerator,你只需要几行代码即可完成数据集的导入:

my_image_size = (32,32)
my_input_shape = my_image_size + (3,)
# 指定训练次数
my_train_epochs = 2
# 指定batch
my_batch = 32
# 创建
两个数据生成器,指定scaling范围0~1
train_datagen = ImageDataGenerator(rescale=1/255)
# validation_datagen = ImageDataGenerator(rescale=1/255)

# 指向训练数据文件夹
train_generator = train_datagen.flow_from_directory(
    './dataset/Training',           # 训练数据所在文件夹
    target_size=my_image_size,         # 指定输出尺寸
    batch_size=my_batch,
    color_mode=  'rgb' if my_input_shape[2]==3 else 'grayscale',
    class_mode='categorical')            # 指定分类# 创建两个数据生成器,指定scaling范围0~1
train_datagen = ImageDataGenerator(rescale=1/255)
# validation_datagen = ImageDataGenerator(rescale=1/255)

# 指向验证数据文件夹
# validation_generator = validation_datagen.flow_from_directory(
#     './dataset/Validation',
#     target_size=my_image_size,
#     batch_size=my_batch,
#     color_mode=  'rgb' if my_input_shape[2]==3 else 'grayscale',
#     class_mode='categorical')

三、验证自定义数据集是否可用

下面以我今年电赛中数字识别的网络模型为例测试上文中自定义的数据集是否可用,关于模型的构建不是本文重点,有兴趣的小伙伴可自行查阅相关资料,下文直接给出完整的测试用例:

3.1 验证纯手工方式加载的数据

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# Python version: 3.5.4
'''
@File    :   tf_data_generator.py
@Time    :   2021/11/04 16:11:19
@Author  :   Wu Xueru 
@Version :   1.0
@Contact :   t01051@163.com
@License :   
@Desc    :   
'''

# here put the import lib
import tensorflow as tf
import pathlib
import matplotlib.pyplot as plt
from time import strftime


my_dataset_path = 'dataset/Training'
my_image_size = (32,32)
my_input_shape = my_image_size + (3,)
# 指定训练次数
my_train_epochs = 2
# 指定batch
my_batch = 32
# shuffle buffer size
my_shuffle_buffer_size = 1000

AUTOTUNE = tf.data.experimental.AUTOTUNE

# 获取所有文件路径
dataset_path = pathlib.Path(my_dataset_path)
all_images_paths = [str(path) for path in list(dataset_path.glob('*/*'))]
print('所有文件的路径:', all_images_paths)
print('文件总数:', len(all_images_paths))

# 获取标签名称
label_name = [i.name for i in dataset_path.iterdir() if i.is_dir()]
print('标签名称:', label_name)
# 因为训练时参数必须为数字,因此为标签分配数字索引
label_index = dict((name,index)for index,name in enumerate(label_name))
print('为标签分配数字索引:', label_index)

# 将图片与标签的数字索引进行配对(number encodeing)
number_encodeing = [label_index[i.split('\\')[2]]for i in all_images_paths]
print('number_encodeing:', number_encodeing, type(number_encodeing))
label_one_hot = tf.keras.utils.to_categorical(number_encodeing, num_classes=10)
print('label_one_hot:', label_one_hot)


def process(path,label):
    # 读入图片文件
    image = tf.io.read_file(path)
    # 将输入的图片解码为gray或者rgb
    image = tf.image.decode_jpeg(image, channels=my_input_shape[2])
    # 调整图片尺寸以满足网络输入层的要求
    image = tf.image.resize(image, my_image_size)
    # 归一化
    image /= 255.
    return image,label

# 将数据与标签拼接到一起
path_ds = tf.data.Dataset.from_tensor_slices((all_images_paths, tf.cast(label_one_hot, tf.int32)))
image_label_ds = path_ds.map(process, num_parallel_calls=AUTOTUNE)
print('image_label_ds:', image_label_ds)
steps_per_epoch=tf.math.ceil(len(all_images_paths)/my_batch).numpy()
print('steps_per_epoch', steps_per_epoch)

# 打乱dataset中的元素并设置batch
image_label_ds = image_label_ds.shuffle(my_shuffle_buffer_size).batch(my_batch)


if __name__ == '__main__':
    # 定义模型
    # 输入层
    input_data = tf.keras.layers.Input(shape=my_input_shape)
    # 第一层
    middle = tf.keras.layers.Conv2D(128, kernel_size=[3,3], strides=(1,1), padding='same', activation=tf.nn.relu)(input_data)
    middle = tf.keras.layers.Conv2D(128, kernel_size=[3,3], strides=(1,1), padding='same', activation=tf.nn.relu)(middle)
    middle = tf.keras.layers.Conv2D(128, kernel_size=[3,3], strides=(1,1), padding='same', activation=tf.nn.relu)(middle)
    middle = tf.keras.layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same')(middle)
    # 第二层
    middle = tf.keras.layers.Conv2D(128, kernel_size=[3,3], strides=(1,1), padding='same', activation=tf.nn.relu)(middle)
    middle = tf.keras.layers.Conv2D(128, kernel_size=[3,3], strides=(1,1), padding='same', activation=tf.nn.relu)(middle)
    middle = tf.keras.layers.Conv2D(128, kernel_size=[3,3], strides=(1,1), padding='same', activation=tf.nn.relu)(middle)
    middle = tf.keras.layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same')(middle)
    # 第三层
    middle = tf.keras.layers.Conv2D(128, kernel_size=[3,3], strides=(1,1), padding='same', activation=tf.nn.relu)(middle)
    middle = tf.keras.layers.Conv2D(128, kernel_size=[3,3], strides=(1,1), padding='same', activation=tf.nn.relu)(middle)
    middle = tf.keras.layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same')(middle)

    # 铺平
    dense = tf.keras.layers.Flatten()(middle)
    dense = tf.keras.layers.Dropout(0.1)(dense)
    dense = tf.keras.layers.Dense(60, activation='relu')(dense)
    # 输出
    # 输出层
    output_data = tf.keras.layers.Dense(len(label_name), activation='softmax')(dense)
    # 确认输入位置和输出位置
    model = tf.keras.Model(inputs=input_data, outputs=output_data)

    # 定义模型的梯度下降和损失函数
    model.compile(optimizer=tf.optimizers.Adam(1e-4), 
                loss=tf.losses.categorical_crossentropy,
                metrics=['accuracy'])

    # 打印模型结构
    model.summary()

    # 开始训练
    start_time = strftime("%Y-%m-%d %H:%M:%S")
    history = model.fit(
        image_label_ds,
        epochs=my_train_epochs,
        verbose=1,
        steps_per_epoch=int(steps_per_epoch))

    end_time = strftime("%Y-%m-%d %H:%M:%S")
    print('开始训练的时间:', start_time)
    print('结束训练的时间:', end_time)

输出结果:
在这里插入图片描述

从输出中可以看到自定义的数据集已经被正确读取,TensorFlow能够正常进行推理。

3.1 验证利用TensorFlow keras ImageDataGenerator方式加载的数据

# here put the import lib
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
from time import strftime
from os import path

my_image_size = (32,32)
my_input_shape = my_image_size + (3,)
# 指定训练次数
my_train_epochs = 2
# 指定batch
my_batch = 32
# shuffle buffer size
my_shuffle_buffer_size = 1000
save_model_path = path.dirname(path.abspath(__file__))

# 创建两个数据生成器,指定scaling范围0~1
train_datagen = ImageDataGenerator(rescale=1/255)
# validation_datagen = ImageDataGenerator(rescale=1/255)

# 指向训练数据文件夹
train_generator = train_datagen.flow_from_directory(
    './dataset/Training',           # 训练数据所在文件夹
    target_size=my_image_size,         # 指定输出尺寸
    batch_size=my_batch,
    color_mode=  'rgb' if my_input_shape[2]==3 else 'grayscale',
    class_mode='categorical')            # 指定分类

# 指向验证数据文件夹
# validation_generator = validation_datagen.flow_from_directory(
#     './dataset/Validation',
#     target_size=my_image_size,
#     batch_size=my_batch,
#     color_mode=  'rgb' if my_input_shape[2]==3 else 'grayscale',
#     class_mode='categorical')

if __name__ == '__main__':
    # 定义模型
    # 输入层
    input_data = tf.keras.layers.Input(shape=my_input_shape)
    # 第一次卷积
    middle = tf.keras.layers.Conv2D(128, kernel_size=[3,3], strides=(1,1), padding='same', activation=tf.nn.relu)(input_data)
    middle = tf.keras.layers.Conv2D(128, kernel_size=[3,3], strides=(1,1), padding='same', activation=tf.nn.relu)(middle)
    middle = tf.keras.layers.Conv2D(128, kernel_size=[3,3], strides=(1,1), padding='same', activation=tf.nn.relu)(middle)
    middle = tf.keras.layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same')(middle)
    # 第二次卷积
    middle = tf.keras.layers.Conv2D(128, kernel_size=[3,3], strides=(1,1), padding='same', activation=tf.nn.relu)(middle)
    middle = tf.keras.layers.Conv2D(128, kernel_size=[3,3], strides=(1,1), padding='same', activation=tf.nn.relu)(middle)
    middle = tf.keras.layers.Conv2D(128, kernel_size=[3,3], strides=(1,1), padding='same', activation=tf.nn.relu)(middle)
    middle = tf.keras.layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same')(middle)
    # 第三次卷积
    middle = tf.keras.layers.Conv2D(128, kernel_size=[3,3], strides=(1,1), padding='same', activation=tf.nn.relu)(middle)
    middle = tf.keras.layers.Conv2D(128, kernel_size=[3,3], strides=(1,1), padding='same', activation=tf.nn.relu)(middle)
    middle = tf.keras.layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same')(middle)

    # 铺平
    dense = tf.keras.layers.Flatten()(middle)
    dense = tf.keras.layers.Dropout(0.1)(dense)
    dense = tf.keras.layers.Dense(60, activation='relu')(dense)
    # 输出
    # 输出层
    output_data = tf.keras.layers.Dense(10, activation='softmax')(dense)
    # 确认输入位置和输出位置
    model = tf.keras.Model(inputs=input_data, outputs=output_data)

    # 定义模型的梯度下降和损失函数
    model.compile(optimizer=tf.optimizers.Adam(1e-4), 
                loss=tf.losses.categorical_crossentropy,
                metrics=['accuracy'])

    # 打印模型结构
    model.summary()

    # 开始训练
    start_time = strftime("%Y-%m-%d %H:%M:%S")
    history = model.fit(
        train_generator,
        epochs=my_train_epochs,
        verbose=1)

    end_time = strftime("%Y-%m-%d %H:%M:%S")
    print('开始训练的时间:', start_time)
    print('结束训练的时间:', end_time)

在这里插入图片描述

同样OK!

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在TensorFlow中使用自定义数据集训练自己的模型 的相关文章

  • 最近的 AWS 区域的客户端 IP 地址

    Question 我想从客户端设备将一些数据上传到 AWS 但我想上传到最近的 AWS 区域的 S3 存储桶 同样 我希望能够从最近的区域下载 当然 我会在每个区域设置一个存储桶 我可以使用一个系统 它可以获取客户端的 IP 地址 然后确定
  • 计时器显示负的已用时间

    我正在使用一个非常简单的代码来计算每个循环的时间for陈述 它看起来像这样 import time for item in list of files Start timing this loop start time clock Do a
  • 将 SQLite 的 FTS3/4 与 Python 3 结合使用

    我一直在使用 python 的 Flask 框架开发 peewee 的示例博客应用程序 看https github com coleifer peewee https github com coleifer peewee 内部示例 gt 博
  • Python 包?

    好吧 我认为无论我做错了什么 它可能都是显而易见的 但我无法弄清楚 我已经阅读并重新阅读了有关包的教程部分 我唯一能想到的是这不起作用 因为我直接执行它 这是目录设置 eulerproject init py euler1 py euler
  • Ttk Treeview:跟踪键盘选择

    这是一个带有 ttk 树视图的 Tk 小部件 当用户单击该行时 会执行某些功能 此处仅打印项目文本 我需要的是以下内容 最初的重点是文本输入 当用户按下 Tab 键时 焦点应该转到第一行 并且应该执行绑定到 Click 事件的函数 当用户使
  • 使用 theano 进行多处理

    我正在尝试将 theano 与 cpu 多处理和神经网络库 Keras 结合使用 I use device gpu标记并加载 keras 模型 然后 为了提取超过一百万张图像的特征 我使用多处理池 该函数看起来像这样 from keras
  • 在 Django 中上传文件

    我在 Django 1 6 版本 中上传文件时遇到问题 当我尝试做的时候new file data save 在我的views py 中我收到此错误 quiz patent 22 medical record 2 exams 处的属性错误
  • 从 python 的单词列表中查找最长的常见单词序列

    我搜索了很多解决方案 确实发现了类似的问题 这个答案 https stackoverflow com questions 21930757 longest repeated substring返回可能不属于输入列表中所有字符串的最长字符序列
  • 为什么 1.__add__(2) 不起作用? [复制]

    这个问题已经存在了 可能的重复 访问 python int 文字方法 https stackoverflow com questions 10955703 accessing a python int literals methods 在P
  • 在 Qt Creator 中相互公开 QML 组件

    我正在使用 Qt Quick 和 PySide2 开发仪表板应用程序 但在 Qt Creator 的设计模式中公开我的 QML 组件时遇到问题 我的文件夹结构如下所示 myapp mycomponents component1 qml co
  • Python Pandas:将参数传递给 agg() 中的函数

    我试图通过使用不同类型的函数和参数值来减少 pandas 数据框中的数据 但是 我无法更改聚合函数中的默认参数 这是一个例子 gt gt gt df pd DataFrame x 1 np nan 2 1 y a a b b gt gt g
  • 如何检查两个数据集的匹配列之间的相关性?

    如果我们有数据集 import pandas as pd a pd DataFrame A 34 12 78 84 26 B 54 87 35 25 82 C 56 78 0 14 13 D 0 23 72 56 14 E 78 12 31
  • 在Python中使用Counter()来构建直方图?

    我在另一个问题上看到我可以使用Counter 计算一组字符串中出现的次数 所以如果我有 A B A C A A I get Counter A 3 B 1 C 1 但现在 我如何使用该信息来构建直方图 对于您的数据 最好使用条形图而不是直方
  • Python:多重分配与单独分配速度

    我一直在寻求从我的代码中挤出更多的性能 最近 在浏览时这个 Python 维基页面 https wiki python org moin PythonSpeed 我发现了这个说法 多重分配比单独分配慢 例如 x y a b 比 x a y
  • 我的 R 平方分数为负,但使用 k 倍交叉验证的准确度分数约为 92%

    对于下面的代码 我的 r 平方分数为负 但使用 k 折交叉验证的准确度分数为 92 这怎么可能 我使用随机森林回归算法来预测一些数据 数据集的链接在下面的链接中给出 https www kaggle com ludobenistant hr
  • 在 Mac OS x 10.7.5 中运行 Scrapy 所需的文件,使用 Python 2.7.3 IEPD_free(32 位)

    我是第一次测试 scrapy 使用命令安装后 sudo easy install U scrapy 一切似乎都运行正常 但是 当我运行时 scrapy startproject tutorial 我得到以下信息 luismacbookpro
  • Twitter 不再使用请求库 python

    我有一个 python 函数 它使用 requests 库和 BeautifulSoup 来抓取特定用户的推文 import requests from bs4 import BeautifulSoup contents requests
  • 按工作日分组的熊猫 (M/T/W/T/F/S/S)

    我有一个 pandas 数据框 其中包含 YYYY MM DD arrival date 形式的时间序列 作为索引 我想按每个工作日 周一到周日 进行分组 以便计算其他日期列是平均值 中位数 标准差等 我最终应该只有七行 到目前为止我只知道
  • Pygame 文本不渲染

    好的 我正在用 python 和 pygame 制作一个多项选择测验游戏 不过 我已经完成了开始屏幕并尝试制作问题屏幕 我根本不明白为什么文本不呈现 这是我的代码 enter pressed False random question ra
  • 仅在满足条件时添加到字典

    我在用urllib urlencode构建 Web POST 参数 但是有一些值我只想在除None为他们而存在 apple green orange orange params urllib urlencode apple apple or

随机推荐

  • 自定义指令 v-loading

    1 在src下创建directive文件夹 2 在directive文件夹下创建loading文件夹 3 loading文件夹内创建index js和loading vue 目录图 4 index js src directive load
  • QtCreator 快捷键问题记录

    我目前用的QtCreator Mac版8 0 0 具体信息如下 一般来说QtCreator的快捷键和设置项在windows下也是一样的 在QtCreator gt Options gt Environment gt Keyboard中可以找
  • SpringBoot整合office转换与预览

    文章目录 一 介绍 1 简介 2 aspose简介 3 jodconverter简介 二 springboot整合aspose实战 1 前期依赖准备 1 1 介绍 1 2 项目直接引入jar包 1 3 maven添加本地包 2 office
  • 使用Retrofit上传实体类到服务端(笔记)

    一 服务端 1 需要对参数用 RequestBody这个注解进行修饰 SpringBoot会自动将前端传过来的JSON数据反序列化成Java对象 登录 param requestVo return PostMapping value log
  • DOTA数据集标签txt文件转为xml文件

    文章目录 1 txt文件格式 2 xml文件格式 3 一般的txt到xml的转换思路 4 最终使用的txt到xml转换的脚本 5 之后可能用到的xml转换到txt的脚本 1 txt文件格式 DOTA数据集的txt文件格式如下 其中 每一行的
  • Springboot整合SpringSecurity

    使用Basic认证模式 1 maven依赖
  • 26.JavaWeb-SpringSecurity安全框架

    1 SpringSecurity安全框架 Spring Security是一个功能强大且灵活的安全框架 它专注于为Java应用程序提供身份验证 Authentication 授权 Authorization 和其他安全功能 Spring S
  • csv反序列化_序列化与反序列化

    toc 定义 序列化 将对象或数据结构转换成约定格式数据的过程 反序列化 将约定格式的数据转换成对象或数据结构的过程 通常我们将这种 约定格式的数据 称之为序列化协议 根据协议的特点序列化协议可以细分为文本序列化协议 以下简称文本协议 和二
  • 网线直连NUC调试并使用VSCode实现X11转发(Jetson,树莓派适用)

    1 场景描述 此种场景下 NUC与PC机通过一根网线进行连接 网线负责PC与NUC进行通信 SSH连接 同时可以将NUC的图形界面转发到PC 远程桌面或X11窗口转发均可 方便战队成员在没有显示器的场景下对NUC进行调试 配置示例如下图所示
  • springboot微服务前端传参数至后端的几个方式,@RequestBody如何传入多个参数@RequestParam

    一 问题 在接口测试工具中 常常要传入参数 初学者 也就是我菜鸡经常传错参数 不明白在Query还是Body里面传参 以及测试工具 AxxPoxx 测试下载接口的时候为什么发送数据成功却没有下载文档下来 后端参数传输方式 二 解决 quer
  • JNDI 资源

    第 6 章 JNDI 资源 Java 命名和目录接口 Java Naming and Directory Interface JNDI 是一种应用编程接口 application programming interface API 用于访问
  • "NO 3D support is available from the host"

    https forums opensuse org showthread php 494522 No 3d Support or graphics accelleration http askubuntu com questions 537
  • 【项目设计】高并发内存池 (四)[pagecache实现]

    C 学习历程 入门 博客主页 一起去看日落吗 持续分享博主的C 学习历程 博主的能力有限 出现错误希望大家不吝赐教 分享给大家一句我很喜欢的话 也许你现在做的事情 暂时看不到成果 但不要忘记 树 成长之前也要扎根 也要在漫长的时光 中沉淀养
  • 解决出现“raw.githubusercontent.com (raw.githubusercontent.com)

    服务器安装软件的时候出现 正在连接 raw githubusercontent com raw githubusercontent com 0 0 0 0 443 失败 拒绝连接 是因为改网址是被墙的 但是还是需要安装软件怎么办 打开多个地
  • 固态U盘量产:群联PS3111主控开卡量产工具使用教程

    PS3111开卡量产工具是一款专门用来进行量产的软件工具 下面将为大家提供使用教程 以帮助大家更加顺利地进行U盘量产 1 下载PS3111开卡量产工具 首先 需要在量产部落官网下载该工具并解压到电脑上 2 连接U盘 将需要进行量产的固态U盘
  • C++ opencv视频处理与保存

    1 视频属性类型 视频有很多的属性 有时长 分辨率 帧宽度 帧高度 帧速率等 视频属性中 由于国内互联网视频网站的定义 我们对分辨率的区分有些误区 所以这里重新介绍一下视频的分辨率 至于其他属性 一般不会有什么误区 分辨率 通常国际标准 我
  • 大语言模型高质量提示词最佳实践

    大语言模型高质量提示词最佳实践 一 提供更清晰的指令 使用大语言模型 类似ChatGPT Bard等工具 的过程中 一个关键的技巧是能够给出清晰和明确的指令 大语言模型的运作方式是根据提供的输入 预测接下来应该生成什么内容 因此 给出明确的
  • Vue的UI网页创建和引入Element组件

    第一步 新创建一个文件夹用来生成vue项目 第二步 进入这个文件夹 在路径栏输入CMD打开DOS窗口 第三步 打开cmd窗口 输入命令 vue ui 第四步 输入命令后运行 浏览器会自动打开vue ui 网页 第五步 点击仪表盘上方的文本框
  • Keepalived--05--脑裂问题

    一 问题 1 1 场景 高可用 在高可用 HA 系统中 当联系2个节点的 心跳线 断开时 本来为一整体 动作协调的HA系统 就分裂成为2个独立的个体 由于相互失去了联系 都以为是对方出了故障 两个节点上的HA软件像 裂脑人 一样 争抢 共享
  • 在TensorFlow中使用自定义数据集训练自己的模型

    在TensorFlow中使用自定义数据集训练自己的模型 写在前面的话 一 自定义数据集的目录结构 以今年电赛数字识别为例 二 在TensorFlow中读取数据集 2 1 纯手工打造 2 1 1 获取所有图片的路径 2 1 2 获取标签并转换