Tensorflow:在CPU上的多个线程中加载数据

2024-01-14

我有一个 python 课程SceneGenerator它有多个用于预处理的成员函数和一个生成器函数generate_data()。基本结构是这样的:

class SceneGenerator(object):
    def __init__(self):
       # some inits

    def generate_data(self):
        """
        Generator. Yield data X and labels y after some preprocessing
        """
        while True:
            # opening files, selecting data
            X,y = self.preprocess(some_params, filenames, ...)            

            yield X, y

我使用keras model.fit_generator()函数中的类成员函数sceneGenerator.generate_data()从磁盘读取数据,对其进行预处理并产生它。在 keras 中,这是在多个 CPU 线程上完成的,如果workers的参数model.fit_generator()设置为 > 1。

我现在想用同样的SceneGenerator张量流中的类。我目前的做法是这样的:

sceneGenerator = SceneGenerator(some_params)
for X, y in sceneGenerator.generate_data():

    feed_dict = {ops['data']: X,
                 ops['labels']: y,
                 ops['is_training_pl']: True
                 }
    summary, step, _, loss, prediction = sess.run([optimization_op, loss_op, pred_op],
                                                  feed_dict=feed_dict)

然而,这很慢并且不使用多线程。我找到了tf.data.Dataset https://www.tensorflow.org/versions/master/api_docs/python/tf/data/Datasetapi 与一些文档 https://www.tensorflow.org/versions/master/programmers_guide/datasets,但我未能实现这些方法。

Edit:请注意,我不处理图像,因此带有文件路径等的图像加载机制在这里不起作用。 我的SceneGenerator从 hdf5 文件加载数据。但不是完整的数据集,而是 - 根据初始化参数 - 仅数据集的一部分。我很想保持生成器功能不变,并了解如何将该生成器直接用作张量流的输入并在 CPU 上的多个线程上运行。将 hdf5 文件中的数据重写为 csv 并不是一个好的选择,因为它会重复大量数据。

Edit 2::我认为类似的东西可能会有所帮助:并行化 tf.data.Dataset.from_generator https://stackoverflow.com/questions/47086599/parallelising-tf-data-dataset-from-generator


假设您使用的是最新的 Tensorflow(撰写本文时为 1.4),您可以保留生成器并使用tf.data.* https://www.tensorflow.org/api_docs/python/tf/dataAPI如下(我为线程数、预取缓冲区大小、批量大小和输出数据类型选择任意值):

NUM_THREADS = 5
sceneGen = SceneGenerator()
dataset = tf.data.Dataset.from_generator(sceneGen.generate_data, output_types=(tf.float32, tf.int32))
dataset = dataset.map(lambda x,y : (x,y), num_parallel_calls=NUM_THREADS).prefetch(buffer_size=1000)
dataset = dataset.batch(42)
X, y = dataset.make_one_shot_iterator().get_next()

为了表明它实际上是从生成器中提取的多个线程,我修改了您的类,如下所示:

import threading    
class SceneGenerator(object):
  def __init__(self):
    # some inits
    pass

  def generate_data(self):
    """
    Generator. Yield data X and labels y after some preprocessing
    """
    while True:
      # opening files, selecting data
      X,y = threading.get_ident(), 2 #self.preprocess(some_params, filenames, ...)            
      yield X, y

这样,创建一个 Tensorflow 会话并获取一批即可显示获取数据的线程的线程 ID。在我的电脑上,运行:

sess = tf.Session()
print(sess.run([X, y]))

prints

[array([  8460.,   8460.,   8460.,  15912.,  16200.,  16200.,   8460.,
         15912.,  16200.,   8460.,  15912.,  16200.,  16200.,   8460.,
         15912.,  15912.,   8460.,   8460.,   6552.,  15912.,  15912.,
          8460.,   8460.,  15912.,   9956.,  16200.,   9956.,  16200.,
         15912.,  15912.,   9956.,  16200.,  15912.,  16200.,  16200.,
         16200.,   6552.,  16200.,  16200.,   9956.,   6552.,   6552.], dtype=float32),
 array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])]

Note:您可能想尝试删除map调用(我们仅用于多线程)并检查是否prefetch的缓冲区足以消除输入管道中的瓶颈(即使只有一个线程,输入预处理通常比实际图形执行速度更快,因此缓冲区足以使预处理尽可能快地进行)。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow:在CPU上的多个线程中加载数据 的相关文章

随机推荐

  • Linux 点阵打印机上的 Java 打印质量

    我需要将报告从 Java 桌面应用程序打印到点阵打印机 Epson LX 300 II 报告由文本和一些图形组成 打印机通过 USB 连接 我使用 CUPS 进行打印 我正在使用 Printable 接口 Java 中相当标准 进行打印 我
  • 尝试在 JFrame 中显示 URL 图像

    尝试在 JFrame 窗口中显示 URL 图像 如果工作正常 当程序运行时 应该打开一个窗口显示图像 尝试尝试 URL 和硬盘路径 import java awt image BufferedImage import java io IOE
  • MVC6 Cors - 拦截飞行前

    我正在将 WebApi 升级到 MVC6 在 WebApi 中 我可以拦截每个 HTTP 请求 如果是预检 我可以使用浏览器可接受的标头进行响应 我试图弄清楚如何在 MVC6 WebApi 中做同样的事情 这是 WebApi 代码 prot
  • 似乎无法在 text() 和 textfield() 之间使用 Spacer()

    当我在 HStack 中并尝试在 Text 和 Textfield 视图之间创建空间时 我似乎无法使用 Spacer 函数 间隔器用于间隔视图的其他区域 但每当我尝试在这两个元素之间间隔时 它就不起作用 这是我正在使用的代码 VStack
  • Kendo UI MVC 4:窗口内的表单验证不会触发

    使用 ASP NET MVC 4 我声明了一个窗口 其中通过 LoadContentFrom 加载了内部内容 表单 Html Kendo Window Name windowAttachClient Title Attach Client
  • d3.json() 回调中的代码未执行

    我正在尝试加载 GeoJSON 文件并使用它作为 D3 的基础来绘制一些图形v5 问题是浏览器跳过了包含在d3 json 称呼 我尝试插入断点来测试 但浏览器会跳过它们 我不明白为什么 下面的代码片段 d3 json trip animat
  • EF5 Code First - 数据注释与 Fluent API [重复]

    这个问题在这里已经有答案了 我是实体框架新手 即将开始使用 EF5 Code First 的新 ASP NET MVC 项目 据我了解 您可以对域模型对象中的属性使用数据注释 也可以使用 Fluent API 来定义属性数据类型 创建对象时
  • 将 CSS 样式应用于 DIV 内的所有元素

    我想将 CSS 文件应用到页面中的具体 DIV 这是页面结构 div div all the elements here must follow a concrete CSS rules div 我尝试应用 CSS 文件的规则进行编辑 CS
  • Vim 输入不是来自终端[重复]

    这个问题在这里已经有答案了 which django admin py vim Vim Warning Input is not from a terminal Vim Error reading input exiting Vim Fin
  • 如何解决Java舍入双精度问题[重复]

    这个问题在这里已经有答案了 似乎减法引发了某种问题 并且结果值是错误的 double tempCommission targetPremium doubleValue rate doubleValue 100d 78 75 787 5 10
  • PHP 错误处理

    提前谢谢大家了 我目前正在调整 改进我为公司从头开始编写的 MVC 框架 它相对较新 因此肯定是不完整的 我需要将错误处理合并到框架中 一切都应该能够访问错误处理 并且它应该能够处理不同类型和级别的错误 用户错误和框架错误 我的问题是做到这
  • 有没有一种很好的方法来增加可选的 Int 值?

    我想增加一个Int 目前我已经写了这个 return index nil index 1 nil 有没有更漂亮的方法来写这个 您可以致电advanced by 函数使用可选链接 return index advancedBy 1 Note
  • 计算,用逗号替换点

    我有一个订单表格 我在其中使用 jQuery 计算插件来总结总数 这种求和工作正常 但生成的 总和 存在问题 总之 我希望用逗号替换任何点 该代码的基础是 function this var sum this sum totaal html
  • 使用 vbscript 进行进程间通信

    我需要将数据从一个进程发送到另一个进程 限制条件 发送方进程是非常昂贵的调用 需要使用 vbscipt 来完成 对于Sender进程来说 这个数据传输是一项额外的工作 它应该不会受到这个特性的太大影响 4 5 分钟内 发送方进程中大约有 1
  • 数据未转换 Node.js 转换流

    我正在尝试创建一个从以下位置获取数据的转换流socket io 将其转换为 JSON 然后将其发送到 stdout 我完全困惑为什么数据似乎没有任何转换就直接通过 我正在使用through2图书馆 这是我的代码 getStreamNames
  • 访问没有字符的字符串的第一个字符

    我正在用 C 实现后缀特里树 实施Trie构造函数如下所示 include
  • npm 错误!代码 ELIFECYCLE(起始问题)

    感谢您阅读本文并帮助解决该问题 我正在尝试在 Windows 计算机上运行 nodejs 并在安装 expo cli 后启动 expo 客户端 最初它工作正常 除了实时刷新或任何其他刷新不起作用 所以我尝试再次删除 卸载 重新安装nodej
  • 救援 CSV::MalformedCsvError:第 n 行中的非法引用

    在尝试解析数组 AR 模型导入等时 出现有问题的 CSV 文件似乎是一个常见问题 除了在 MS Excel 中打开之外 我还没有找到可行的解决方案save as每天 还不够好 在外部提供的 60 000 行 每日更新的 csv 文件中 存在
  • 喷雾罐 NoClassDefFoundError

    我是喷雾新手 我无法让它工作 我的构建 sbt val apacheDeps Seq commons validator commons validator 1 4 1 val sprayAndAkkaDeps val sprayV 1 3
  • Tensorflow:在CPU上的多个线程中加载数据

    我有一个 python 课程SceneGenerator它有多个用于预处理的成员函数和一个生成器函数generate data 基本结构是这样的 class SceneGenerator object def init self some