我是 Tensorflow 和深度学习的新手,并且在 Dataset 类上遇到了困难。我尝试了很多方法,但找不到好的解决方案。
我正在尝试什么
我有大量图像 (500k+) 来训练我的 DNN。这是一个去噪自动编码器,所以我每个图像都有一对。我正在使用TF的数据集类来管理数据,但我认为我用得非常糟糕。
以下是我在数据集中加载文件名的方法:
class Data:
def __init__(self, in_path, out_path):
self.nb_images = 512
self.test_ratio = 0.2
self.batch_size = 8
# load filenames in input and outputs
inputs, outputs, self.nb_images = self._load_data_pair_paths(in_path, out_path, self.nb_images)
self.size_training = self.nb_images - int(self.nb_images * self.test_ratio)
self.size_test = int(self.nb_images * self.test_ratio)
# split arrays in training / validation
test_data_in, training_data_in = self._split_test_data(inputs, self.test_ratio)
test_data_out, training_data_out = self._split_test_data(outputs, self.test_ratio)
# transform array to tf.data.Dataset
self.train_dataset = tf.data.Dataset.from_tensor_slices((training_data_in, training_data_out))
self.test_dataset = tf.data.Dataset.from_tensor_slices((test_data_in, test_data_out))
我有一个函数可以在每个时期调用来准备数据集。它会打乱文件名,并将文件名转换为图像和批处理数据。
def get_batched_data(self, seed, batch_size):
nb_batch = int(self.size_training / batch_size)
def img_to_tensor(path_in, path_out):
img_string_in = tf.read_file(path_in)
img_string_out = tf.read_file(path_out)
im_in = tf.image.decode_jpeg(img_string_in, channels=1)
im_out = tf.image.decode_jpeg(img_string_out, channels=1)
return im_in, im_out
t_datas = self.train_dataset.shuffle(self.size_training, seed=seed)
t_datas = t_datas.map(img_to_tensor)
t_datas = t_datas.batch(batch_size)
return t_datas
现在在训练期间,在每个时期我们称get_batched_data
函数,创建一个迭代器,并为每个批次运行它,然后将数组提供给优化器操作。
for epoch in range(nb_epoch):
sess_iter_in = tf.Session()
sess_iter_out = tf.Session()
batched_train = data.get_batched_data(epoch)
iterator_train = batched_train.make_one_shot_iterator()
in_data, out_data = iterator_train.get_next()
total_batch = int(data.size_training / batch_size)
for batch in range(total_batch):
print(f"{batch + 1} / {total_batch}")
in_images = sess_iter_in.run(in_data).reshape((-1, 64, 64, 1))
out_images = sess_iter_out.run(out_data).reshape((-1, 64, 64, 1))
sess.run(optimizer, feed_dict={inputs: in_images,
outputs: out_images})
我需要什么 ?
我需要一个仅加载当前批次的图像的管道(否则它将不适合内存),并且我想为每个时期以不同的方式对数据集进行洗牌。
疑问和问题
第一个问题,我是否以良好的方式使用 Dataset 类?我在互联网上看到了非常不同的东西,例如this https://towardsdatascience.com/how-to-use-dataset-in-tensorflow-c758ef9e4428博客文章数据集与占位符一起使用,并在学习过程中使用数据进行馈送。这看起来很奇怪,因为数据都在一个数组中,所以加载到内存中。我不明白使用的意义tf.data.dataset
在这种情况下。
我通过使用找到了解决方案repeat(epoch)
在数据集上,例如this https://stackoverflow.com/a/47217160/10528024,但在这种情况下,每个时期的洗牌不会不同。
我的实施的第二个问题是我有一个OutOfRangeError
在某些情况下。对于少量数据(如示例中的 512),它可以正常工作,但是对于较大量的数据,就会出现错误。我认为这是因为由于四舍五入错误而导致批次数计算错误,或者当最后一个批次的数据量较小时,但它发生在 115 个批次中的第 32 个批次中......有什么方法可以知道之后创建的批次数batch(n)
调用数据集?
很抱歉问了这个冗长的问题,但这几天我一直在努力解决这个问题。