如何正确使用批处理 Tensorflow 数据集?

2023-12-21

我是 Tensorflow 和深度学习的新手,并且在 Dataset 类上遇到了困难。我尝试了很多方法,但找不到好的解决方案。

我正在尝试什么

我有大量图像 (500k+) 来训练我的 DNN。这是一个去噪自动编码器,所以我每个图像都有一对。我正在使用TF的数据集类来管理数据,但我认为我用得非常糟糕。

以下是我在数据集中加载文件名的方法:

class Data:
def __init__(self, in_path, out_path):
    self.nb_images = 512
    self.test_ratio = 0.2
    self.batch_size = 8

    # load filenames in input and outputs
    inputs, outputs, self.nb_images = self._load_data_pair_paths(in_path, out_path, self.nb_images)

    self.size_training = self.nb_images - int(self.nb_images * self.test_ratio)
    self.size_test = int(self.nb_images * self.test_ratio)

    # split arrays in training / validation
    test_data_in, training_data_in = self._split_test_data(inputs, self.test_ratio)
    test_data_out, training_data_out = self._split_test_data(outputs, self.test_ratio)

    # transform array to tf.data.Dataset
    self.train_dataset = tf.data.Dataset.from_tensor_slices((training_data_in, training_data_out))
    self.test_dataset = tf.data.Dataset.from_tensor_slices((test_data_in, test_data_out))

我有一个函数可以在每个时期调用来准备数据集。它会打乱文件名,并将文件名转换为图像和批处理数据。

def get_batched_data(self, seed, batch_size):
    nb_batch = int(self.size_training / batch_size)

    def img_to_tensor(path_in, path_out):
        img_string_in = tf.read_file(path_in)
        img_string_out = tf.read_file(path_out)
        im_in = tf.image.decode_jpeg(img_string_in, channels=1)
        im_out = tf.image.decode_jpeg(img_string_out, channels=1)
        return im_in, im_out

    t_datas = self.train_dataset.shuffle(self.size_training, seed=seed)
    t_datas = t_datas.map(img_to_tensor)
    t_datas = t_datas.batch(batch_size)
    return t_datas

现在在训练期间,在每个时期我们称get_batched_data函数,创建一个迭代器,并为每个批次运行它,然后将数组提供给优化器操作。

for epoch in range(nb_epoch):
    sess_iter_in = tf.Session()
    sess_iter_out = tf.Session()

    batched_train = data.get_batched_data(epoch)
    iterator_train = batched_train.make_one_shot_iterator()
    in_data, out_data = iterator_train.get_next()

    total_batch = int(data.size_training / batch_size)
    for batch in range(total_batch):
        print(f"{batch + 1} / {total_batch}")
        in_images = sess_iter_in.run(in_data).reshape((-1, 64, 64, 1))
        out_images = sess_iter_out.run(out_data).reshape((-1, 64, 64, 1))
        sess.run(optimizer, feed_dict={inputs: in_images,
                                       outputs: out_images})

我需要什么 ?

我需要一个仅加载当前批次的图像的管道(否则它将不适合内存),并且我想为每个时期以不同的方式对数据集进行洗牌。

疑问和问题

第一个问题,我是否以良好的方式使用 Dataset 类?我在互联网上看到了非常不同的东西,例如this https://towardsdatascience.com/how-to-use-dataset-in-tensorflow-c758ef9e4428博客文章数据集与占位符一起使用,并在学习过程中使用数据进行馈送。这看起来很奇怪,因为数据都在一个数组中,所以加载到内存中。我不明白使用的意义tf.data.dataset在这种情况下。

我通过使用找到了解决方案repeat(epoch)在数据集上,例如this https://stackoverflow.com/a/47217160/10528024,但在这种情况下,每个时期的洗牌不会不同。

我的实施的第二个问题是我有一个OutOfRangeError在某些情况下。对于少量数据(如示例中的 512),它可以正常工作,但是对于较大量的数据,就会出现错误。我认为这是因为由于四舍五入错误而导致批次数计算错误,或者当最后一个批次的数据量较小时,但它发生在 115 个批次中的第 32 个批次中......有什么方法可以知道之后创建的批次数batch(n)调用数据集?

很抱歉问了这个冗长的问题,但这几天我一直在努力解决这个问题。


据我所知,官方表现指南 https://www.tensorflow.org/performance/datasets_performance是制作输入管道的最佳教材。

我想为每个时期以不同的方式对数据集进行洗牌。

使用 shuffle() 和 Repeat(),您可以为每个时期获得不同的洗牌模式。您可以通过以下代码确认

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4])
dataset = dataset.shuffle(4)
dataset = dataset.repeat(3)

iterator = dataset.make_one_shot_iterator()
x = iterator.get_next()

with tf.Session() as sess:
    for i in range(10):
        print(sess.run(x))

您还可以使用上面官方页面提到的 tf.contrib.data.shuffle_and_repeat 。

除了创建数据管道之外,您的代码中还存在一些问题。您将图构建与图执行混淆了。您正在重复创建数据输入管道,因此有许多与纪元一样多的冗余输入管道。您可以通过 Tensorboard 观察冗余管道。

您应该将图形构建代码放置在循环之外,如以下代码(伪代码)

batched_train = data.get_batched_data()
iterator = batched_train.make_initializable_iterator()
in_data, out_data = iterator_train.get_next()

for epoch in range(nb_epoch):
    # reset iterator's state
    sess.run(iterator.initializer)

    try:
        while True:
            in_images = sess.run(in_data).reshape((-1, 64, 64, 1))
            out_images = sess.run(out_data).reshape((-1, 64, 64, 1))
            sess.run(optimizer, feed_dict={inputs: in_images,
                                           outputs: out_images})
    except tf.errors.OutOfRangeError:
        pass

而且还有一些不重要的低效代码。您使用 from_tensor_slices() 加载了文件路径列表,因此该列表已嵌入到您的图中。 (看https://www.tensorflow.org/guide/datasets#consuming_numpy_arrays https://www.tensorflow.org/guide/datasets#consuming_numpy_arrays详情)

您最好使用预取,并通过组合图表来减少 sess.run 调用。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何正确使用批处理 Tensorflow 数据集? 的相关文章

随机推荐

  • 检测 Mac OS X 上的调试器

    我试图检测我的进程是否正在调试器中运行 在 Windows 中有很多解决方案 在 Linux 中我使用 ptrace PTRACE ME 0 0 0 并检查其返回值 我没有设法在 Mac OS X 上执行相同的基本检查 我尝试使用 ptra
  • 请解释一下使用 std::ignore 的这段代码

    我正在阅读有关的文档std ignore http en cppreference com w cpp utility tuple ignore来自 cppreference 我发现很难掌握这个对象的真正目的 并且示例代码并没有很好地说明这
  • 如何使用 RxJS 显示“用户正在输入”指示器?

    我了解一点 BaconJS 但现在我尝试通过创建 用户正在输入 指示器来学习 RxJS 这很简单 可以用两个简单的规则来解释 当用户打字时 指示器应该立即可见 当用户停止键入时 指示器应该仍然可见 直到用户最后一次键入操作后 1 秒 我不确
  • 矩阵的行主布局与列主布局

    在编程密集矩阵计算时 是否有任何理由选择行优先布局而不是列优先布局 我知道 根据所选矩阵的布局 我们需要编写适当的代码来有效地使用缓存以达到速度目的 行主布局看起来更自然 更简单 至少对我来说 但是像 LAPACK 这样用 Fortran
  • 使用 Oracle 11g 的 Oracle 开发人员虚拟机

    I found here http www oracle com technetwork community developer vm index html带有 Oracle DB 和 Oracle Linux 的 VirtualBox 的
  • 编辑RefineryCMS 2.1 Menu Presenter以操作dom_id css和其他属性

    我目前正在尝试从 gem 版本更新现有的 RefineryCMS 应用程序 以便我可以同时添加 bootstrap 3 gem refinerycms gt 2 0 10 to gem refinerycms gt 2 1 0 在删除过时的
  • IO异常:Oracle升级到12g后出现Oracle Error ORA-12650

    在我们的 Oracle DB 从 11g 升级到 12g 后 我们得到了下面的堆栈跟踪 Io exception Oracle Error ORA 12650 我该如何解决这个问题 2015 10 26 14 59 36 319 RMI T
  • 需要一个插入行并返回 ID 的存储过程

    我尝试编写一个存储过程 首先将新记录插入表中 然后返回该新记录的 id 我不确定这是否是实现这一目标的正确方法和最佳方法 ALTER PROCEDURE dbo spAddAsset Name VARCHAR 500 URL VARCHAR
  • minitest - 模拟 - 期望关键字参数

    当我想验证模拟是否发送了预期的参数时 我可以这样做 mock expect fnc nil a b 但是 如果我想模拟的课程看起来像这样 class Foo def fnc a b end end 我如何模拟它并验证传递的值a b 下面是我
  • 判断一个数是完美数还是素数

    问题是 编写一个函数来判断一个数是素数还是完全数 到目前为止 我已经首先完成了完美的部分 这就是我所拥有的 include
  • Excel 如何比较 2 列范围

    我正在尝试比较 Excel 中的两组列范围 我知道标准比较公式 Eg A1 E1 我正在寻找的是以下公式的替代品 AND A1 E1 B1 F1 C1 G1 由于列数很大 我在想是否可以使用单元格范围 比 Chronocidal慢一点 只是
  • 星际争霸、帝国时代等即时战略游戏的协议是什么样的? [关闭]

    Closed 这个问题需要多问focused help closed questions 目前不接受答案 我对这些类型的游戏的协议 和游戏循环 如何工作感兴趣 任何指示或见解表示赞赏 我猜想主循环会有一个世界状态 每秒会前进几个 滴答声 但
  • 如何阻止 gke-metadata-server 继续生成此日志?

    我创建了一个部署 意味着在启用工作负载身份的情况下将消息从 pubsub 插入到 bigquery 云日志不断向我发送此类日志 insertId test jsonPayload message rpc id test computeMet
  • 是否有关于 Dagger 在注入依赖项时何时回退到反射的文档?

    我的团队在我们的 Android 应用程序中采用了 Dagger 进行依赖注入 我必须说到目前为止我们很喜欢它 然而 我们希望确保我们有效地使用它 我想知道是否有人可以解释或者是否有任何文档解释 Dagger 回退到反射来注入依赖项的情况
  • 如何使用 EmberJS 在路由中加载 ownTo/hasMany 关系

    在我的 Ember JS 应用程序中 我显示了一个约会列表 在约会控制器的操作中 我需要获取约会所有者 但所有者始终返回 未定义 我的文件 模型 appointment js import DS from ember data export
  • R CMD 检查因“未定义的导出”而失败

    我正在尝试创建 R 包 但不断收到错误 Error in namespaceExport ns exports undefined exports MCLE defineFunctions naiveMLE 跑步时R CMD check在我
  • JPA2 Criteria API 运行时从 varchar(25) 转换为十进制

    因此 我已经看到类似主题上堆栈溢出的所有线程 但我没有找到解决我的问题的方法 我正在尝试创建一个 Criteria 查询 并得到以下 SQL 第一个 SQL 简化版 SELECT latitude FROM stations WHERE A
  • Javascript momentjs 将 UTC 从字符串转换为日期对象

    各位 在处理 moment js 文档时遇到困难 record lastModified moment utc format returns 2014 11 11T21 29 05 00 00 太棒了 它是 UTC 当我将其存储在 Mong
  • 在 Windows 上使用 C 将数据流式传输到声卡 [关闭]

    Closed 此问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 作为大学项目的一部分 我必须进行一些信号处理 并希望使用 PC 声卡输出结果 该软件必须用 C 语言编写
  • 如何正确使用批处理 Tensorflow 数据集?

    我是 Tensorflow 和深度学习的新手 并且在 Dataset 类上遇到了困难 我尝试了很多方法 但找不到好的解决方案 我正在尝试什么 我有大量图像 500k 来训练我的 DNN 这是一个去噪自动编码器 所以我每个图像都有一对 我正在