如何从 tfrecords 目录创建 tf.data.dataset?

2023-12-08

我的数据集有不同的目录,每个目录对应一个类。每个目录中有不同数量的 .tfrecord。我的问题是如何从每个目录中采样 5 个图像(每个 .tfrecord 文件对应一个图像)? 我的另一个问题是如何对其中 5 个目录进行采样,然后从每个目录中采样 5 个图像。

我只想用 tf.data.dataset 来做。所以我想要一个数据集,从中获得一个迭代器,并且 iterator.next() 为我提供了一批 25 个图像,其中包含来自 5 个类的 5 个样本。


EDIT:如果类的数量大于5,那么你可以使用新的tf.contrib.data.sample_from_datasets()API(当前可用tf-nightly并将在 TensorFlow 1.9 中提供)。

directories = ["class_0/*", "class_1/*", "class_2/*", "class_3/*", ...]

CLASSES_PER_BATCH = 5
EXAMPLES_PER_CLASS_PER_BATCH = 5
BATCH_SIZE = CLASSES_PER_BATCH * EXAMPLES_PER_CLASS_PER_BATCH
NUM_CLASSES = len(directories)


# Build one dataset per class.
per_class_datasets = [
    tf.data.TFRecordDataset(tf.data.Dataset.list_files(d)) for d in directories]

# Next, build a dataset where each element is a vector of 5 classes to be chosen
# for a particular batch.
classes_per_batch_dataset = tf.contrib.data.Counter().map(
    lambda _: tf.random_shuffle(tf.range(NUM_CLASSES))[:CLASSES_PER_BATCH]))

# Transform the dataset of per-batch class vectors into a dataset with one
# one-hot element per example (i.e. 25 examples per batch).
class_dataset = classes_per_batch_dataset.flat_map(
    lambda classes: tf.data.Dataset.from_tensor_slices(
        tf.one_hot(classes, num_classes)).repeat(EXAMPLES_PER_CLASS_PER_BATCH))

# Use `tf.contrib.data.sample_from_datasets()` to select an example from the
# appropriate dataset in `per_class_datasets`.
example_dataset = tf.contrib.data.sample_from_datasets(per_class_datasets,
                                 class_dataset)

# Finally, combine 25 consecutive examples into a batch.
result = example_dataset.batch(BATCH_SIZE)

如果您正好有 5 个类,则可以为每个目录定义一个嵌套数据集并使用Dataset.interleave():

# NOTE: We're assuming that the 0th directory contains elements from class 0, etc.
directories = ["class_0/*", "class_1/*", "class_2/*", "class_3/*", "class_4/*"]
directories = tf.data.Dataset.from_tensor_slices(directories)
directories = directories.apply(tf.contrib.data.enumerate_dataset())    

# Define a function that maps each (class, directory) pair to the (shuffled)
# records in those files.
def per_directory_dataset(class_label, directory_glob):
  files = tf.data.Dataset.list_files(directory_glob, shuffle=True)
  records = tf.data.TFRecordDataset(records)
  # Zip the records with their class. 
  # NOTE: This part might not be necessary if the records contain information about
  # their class that can be parsed from them.
  return tf.data.Dataset.zip(
      (records, tf.data.Dataset.from_tensors(class_label).repeat(None)))

# NOTE: The `cycle_length` and `block_length` here aren't strictly necessary,
# because the batch size is exactly `number of classes * images per class`.
# However, these arguments may be useful if you want to decouple these numbers.
merged_records = directories.interleave(per_directory_dataset,
                                        cycle_length=5, block_length=5)
merged_records = merged_records.batch(25)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何从 tfrecords 目录创建 tf.data.dataset? 的相关文章

随机推荐

  • 在网页中嵌入 Windows 窗体用户控件的步骤

    我正在 Visual Studio 2005 中开发一个 Windows 窗体用户控件 它是一个文件上传控件 仅使用 2 个元素 显示 openfiledialog 的按钮 打开文件对话框 我已经在 html 页面中添加了一个带有类 id
  • Hibernate EntityManager.merge() 不更新数据库

    我有一个使用 Hibernate 的 Spring MVC Web 应用程序 我的问题是em merge拨打电话后没有回复 这是我的控制器 RequestMapping value updDep method RequestMethod P
  • 从 Mysql DB 填充 JFreechart TimeSeriesCollection?

    我正在尝试在我的应用程序中制作一个图表 该图表可以返回几个月内各天的温度 该图表是 JFreechart TimeSeriesCollection 我无法让该图表从数据库读取正确的数据 它显示了一些值 但不是全部 并且不显示正确的时间 为了
  • 为什么 gc() 不释放内存?

    我在一个上运行模拟Windows 64 位计算机 with 64 GB 内存 内存使用达到55 完成模拟运行后 我通过以下方式删除工作空间中的所有对象rm list ls 后面跟着一个double gc 我认为这将为下一次模拟运行释放足够的
  • 如何使用特定网络接口(或特定源 IP 地址)进行 Ping?

    根据这个链接 使用 System Net NetworkInformation 有没有办法将 ping 绑定到特定接口 ICMP 不能绑定到网络接口 与基于套接字的东西不同 ICMP 不是基于套接字的 ping 将根据路由表发送到适当的端口
  • 列表视图滚动不平滑

    我有一个自定义列表视图 显示用户和照片 我从 API 检索数据 它提供 JSON 输出 我的问题是列表视图滚动不顺畅 它挂起一秒钟并滚动 它重复相同的操作直到我们到达末尾 我认为这可能是因为我正在 UI 线程上运行与网络相关的操作 但即使在
  • 实体框架能否在保存时自动将日期时间字段转换为 UTC?

    我正在使用 ASP NET MVC 5 编写一个应用程序 我要存储在数据库中的所有日期时间必须首先从本地时区转换为 UTC 时区 我不确定在请求周期内最好的地方在哪里 我可以在控制器中通过 ViewModel 规则后将每个字段转换为 UTC
  • JS 中的猜数字游戏

    我想创建一个数字游戏 用户输入 1 100 之间的数字 脚本将尝试猜测 10 次用户的输入 如果猜对的数字在 10 以内 则用户获胜 否则用户获胜 到目前为止 我让它正常工作 除了我在尝试让它显示游戏结束时的猜测数量时遇到问题 例如 如果进
  • 如何使用模型/视图/控制器方法制作 GUI?

    我需要理解模型 视图 控制器方法背后的概念以及如何以这种方式编写 GUI 这只是一个非常基本 简单的 GUI 有人可以向我解释如何使用 MVC 重写这段代码吗 from tkinter import class Application Fr
  • 使用 Button Jupyter Notebook 终止循环?

    我想要 从串口读取 无限循环 当按下 STOP 按钮时 gt 停止读取并绘制数据 From 如何通过按键终止 while 循环 我以使用键盘中断为例 这有效 但我想使用一个按钮 键盘中断示例 weights times open port
  • 将 ACE 与 WT 结合使用

    UPDATE 3最终工作代码如下 您需要 src 文件夹中的 ace js 它无法从库中运行 您需要从他们的站点获得预打包版本 WText editor new WText root editor gt setText function n
  • 在 Kubernetes Python 客户端中使用 create_namespaced_secret API

    我必须创建一个像这样的秘密 但是使用Python kubectl create secret generic mysecret n mynamespace from literal etcdpasswd echo n PASSWORD ba
  • 为什么我的坐标区对象的 ButtonDownFcn 回调在绘制某些内容后停止工作?

    我正在图中创建一组轴并为其分配回调 ButtonDownFcn 像这样的财产 HRaxes axes Parent Figure Position 05 60 9 35 XLimMode manual ButtonDownFcn HR Bu
  • 在 unicode 中填充“o”字符或通过 CSS 模仿

    我需要用 HTML 编写此文本 我尝试使用一些 unicode 字符 例如Unicode字符集 黑圈 U 25CF or Unicode字符集 黑色大圆圈 U 2B24 但它们需要一些样式 即尺寸与实际尺寸不同 o 并且在某些系统和字体上显
  • 如何更改 stackplot、matplotlib 的调色板?

    我希望更改 stackplot 的调色板 使大区域具有浅色 较小区域具有明亮颜色 import numpy as np import pandas as pd import matplotlib pyplot as plt import s
  • 循环总结大于 R 中主题的观察结果

    我有一个看起来像这样的数据集 set seed 100 da lt data frame exp c rep A 4 rep B 4 diam runif 8 10 30 对于数据集中的每一行 我想总结大于特定行中的直径并包含在级别 exp
  • 如何在空手道中使用特定于环境的测试数据

    我想知道在各种环境中执行测试时如何在运行时使用不同的数据集 我已阅读文档 但无法找到针对这种情况的最佳解决方案 要求 在 QA 环境中执行测试 然后在 SIT 中执行相同的测试 但是 在请求中使用不同的数据 例如 customerIds 这
  • 如何在AngularJS中渲染之前编译过滤器的结果

    我有一个网络应用程序 用户可以在其中输入 富文本 内容 tinymce 并可能输入超链接 在我的角度应用程序中 我使用 ng bind html unsafe 在 div 中渲染它以保留所有格式 我想将 ng click 附加到该内容中的任
  • 哪个事件被触发? (javascript,输入字段历史记录)

    我有一个空的文本字段 但是当您单击它时 它会显示以前输入的一些建议 如果我用鼠标选择其中一个 JavaScript 事件 会触发哪个 JavaScript 事件 我正在使用 jquery 1 6 2 来绑定侦听器 view textRegi
  • 如何从 tfrecords 目录创建 tf.data.dataset?

    我的数据集有不同的目录 每个目录对应一个类 每个目录中有不同数量的 tfrecord 我的问题是如何从每个目录中采样 5 个图像 每个 tfrecord 文件对应一个图像 我的另一个问题是如何对其中 5 个目录进行采样 然后从每个目录中采样