如何在tensorflow中使用自定义python函数预取数据

2024-02-03

我正在尝试预取训练数据以隐藏 I/O 延迟。我想编写自定义 Python 代码来从磁盘加载数据并预处理数据(例如通过添加上下文窗口)。换句话说,一个线程进行数据预处理,另一个线程进行训练。这在 TensorFlow 中可能吗?

更新:我有一个基于@mrry 示例的工作示例。

import numpy as np
import tensorflow as tf
import threading

BATCH_SIZE = 5
TRAINING_ITERS = 4100

feature_input = tf.placeholder(tf.float32, shape=[128])
label_input = tf.placeholder(tf.float32, shape=[128])

q = tf.FIFOQueue(200, [tf.float32, tf.float32], shapes=[[128], [128]])
enqueue_op = q.enqueue([label_input, feature_input])

label_batch, feature_batch = q.dequeue_many(BATCH_SIZE)
c = tf.reshape(feature_batch, [BATCH_SIZE, 128]) + tf.reshape(label_batch, [BATCH_SIZE, 128])

sess = tf.Session()

def load_and_enqueue(sess, enqueue_op, coord):
  with open('dummy_data/features.bin') as feature_file, open('dummy_data/labels.bin') as label_file:
    while not coord.should_stop():
      feature_array = np.fromfile(feature_file, np.float32, 128)
      if feature_array.shape[0] == 0:
        print('reach end of file, reset using seek(0,0)')
        feature_file.seek(0,0)
        label_file.seek(0,0)
        continue
      label_value = np.fromfile(label_file, np.float32, 128)

      sess.run(enqueue_op, feed_dict={feature_input: feature_array,
                                      label_input: label_value})

coord = tf.train.Coordinator()
t = threading.Thread(target=load_and_enqueue, args=(sess,enqueue_op, coord))
t.start()

for i in range(TRAINING_ITERS):
  sum = sess.run(c)
  print('train_iter='+str(i))
  print(sum)

coord.request_stop()
coord.join([t])

这是一个常见的用例,大多数实现都使用 TensorFlowqueues将预处理代码与训练代码分离。有有关如何使用队列的教程 https://www.tensorflow.org/versions/master/how_tos/threading_and_queues/index.html,但主要步骤如下:

  1. 定义一个队列,q,这将缓冲预处理的数据。 TensorFlow 支持简单的tf.FIFOQueue https://www.tensorflow.org/versions/master/api_docs/python/io_ops.html#FIFOQueue按照元素入队的顺序生成元素,更高级的tf.RandomShuffleQueue https://www.tensorflow.org/versions/master/api_docs/python/io_ops.html#RandomShuffleQueue以随机顺序生成元素。队列元素是一个或多个张量(可以具有不同类型和形状)的元组。所有队列都支持单元素(enqueue, dequeue)和批次(enqueue_many, dequeue_many) 操作,但要使用批处理操作,您必须在构造队列时指定队列元素中每个张量的形状。

  2. 构建一个子图,将预处理的元素排入队列。做到这一点的一种方法是定义一些tf.placeholder() https://www.tensorflow.org/versions/master/api_docs/python/io_ops.html#placeholder对应于单个输入示例的张量的操作,然后将它们传递给q.enqueue() https://www.tensorflow.org/versions/master/api_docs/python/io_ops.html#QueueBase.enqueue。 (如果您的预处理一次产生一批,您应该使用q.enqueue_many() https://www.tensorflow.org/versions/master/api_docs/python/io_ops.html#QueueBase.enqueue_many)您还可以在此子图中包含 TensorFlow 操作。

  3. 构建一个执行训练的子图。这看起来像一个常规的 TensorFlow 图,但将通过调用获取其输入q.dequeue_many(BATCH_SIZE) https://www.tensorflow.org/versions/master/api_docs/python/io_ops.html#QueueBase.dequeue_many.

  4. 开始您的会话。

  5. 创建一个或多个执行预处理逻辑的线程,然后执行入队操作,输入预处理的数据。您可能会发现tf.train.Coordinator https://www.tensorflow.org/versions/master/api_docs/python/train.html#Coordinator and tf.train.QueueRunner https://www.tensorflow.org/versions/master/api_docs/python/train.html#QueueRunner对此有用的实用程序类。

  6. 正常运行您的训练图(优化器等)。

EDIT:这是一个简单的load_and_enqueue()帮助您入门的函数和代码片段:

# Features are length-100 vectors of floats
feature_input = tf.placeholder(tf.float32, shape=[100])
# Labels are scalar integers.
label_input = tf.placeholder(tf.int32, shape=[])

# Alternatively, could do:
# feature_batch_input = tf.placeholder(tf.float32, shape=[None, 100])
# label_batch_input = tf.placeholder(tf.int32, shape=[None])

q = tf.FIFOQueue(100, [tf.float32, tf.int32], shapes=[[100], []])
enqueue_op = q.enqueue([feature_input, label_input])

# For batch input, do:
# enqueue_op = q.enqueue_many([feature_batch_input, label_batch_input])

feature_batch, label_batch = q.dequeue_many(BATCH_SIZE)
# Build rest of model taking label_batch, feature_batch as input.
# [...]
train_op = ...

sess = tf.Session()

def load_and_enqueue():
  with open(...) as feature_file, open(...) as label_file:
    while True:
      feature_array = numpy.fromfile(feature_file, numpy.float32, 100)
      if not feature_array:
        return
      label_value = numpy.fromfile(feature_file, numpy.int32, 1)[0]

      sess.run(enqueue_op, feed_dict={feature_input: feature_array,
                                      label_input: label_value})

# Start a thread to enqueue data asynchronously, and hide I/O latency.
t = threading.Thread(target=load_and_enqueue)
t.start()

for _ in range(TRAINING_EPOCHS):
  sess.run(train_op)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在tensorflow中使用自定义python函数预取数据 的相关文章

  • 为什么最新的 Python 3.8.x 版本不提供 Windows 安装程序?

    我需要在Windows计算机上安装Python 3 8并希望使用最新的小版本3 8 12 https www python org downloads release python 3812 官方发布网页提供了源代码的 tarball 文件
  • matplotlib get_color 用于子图

    我正在按照这里的教程进行操作 https matplotlib org gallery ticks and spines multiple yaxis with spines html https matplotlib org galler
  • Flask - 如何从 JSON GET 请求获取参数

    我有一个发出以下 GET 请求的客户端 api GET tasks 5fe7eabd 842e 40d2 849e 409655e0891d 22task 22 22hello 22 22url 22 22 tasks 5fe7eabd 8
  • Python绕相机轴旋转图像

    假设我有一个图像 是在对某些原始图像应用单应性变换 H 后获得的 未显示原始图像 将单应性 H 应用于原始图像的结果是该图像 我想围绕合适的轴 可能是相机所在的位置 如果有的话 将此图像旋转 30 度以获得此图像 如果我不知道相机参数 如何
  • 如何逐行替换(更新)文件中的文本

    我试图通过读取每一行 测试它 然后写入是否需要更新来替换文本文件中的文本 我不想保存为新文件 因为我的脚本已经先备份文件并对备份进行操作 这是我到目前为止所拥有的 我从 os walk 获取路径 并且保证 pathmatch var 正确返
  • 为什么 np.linalg.norm(..., axis=1) 比写出向量范数公式慢?

    标准化矩阵的行X对于单位长度 我通常使用 X np linalg norm X axis 1 keepdims True 在尝试优化算法的此操作时 我非常惊讶地发现在我的机器上写出标准化的速度大约快了 40 X np sqrt X 0 2
  • 使用文本和进度条填充 DataGridView

    我正在创建一个多线程应用程序 其中每个线程将在我的应用程序中显示为一行DataGridView 我想要一个ProgressBar每行指示相应的线程进度 问题是 这可能吗 如果是这样 怎么办 我添加了类 DataGridView Progre
  • UTF-8 解码如何知道字节边界?

    我一直在阅读大量有关 unicode 编码的文章 尤其是有关 Python 的文章 我想我现在对此已经有了相当深入的了解 但仍有一个小细节我有点不确定 解码如何知道字节边界 例如 假设我有一个带有两个 unicode 字符的 unicode
  • Python 将 0 计算为 False

    在 Python 控制台中 gt gt gt a 0 gt gt gt if a print L gt gt gt a 1 gt gt gt if a print L L gt gt gt a 2 gt gt gt if a print L
  • Django 单元测试数据库没有被拆除?

    我编写了一些单元测试来测试我的 Django 应用程序 特别是一个测试套件中包含大量代码setUp 功能 所述代码的目的是为数据库创建测试数据 是的 我了解固定装置 并且选择在这种情况下不使用它们 当我运行单元测试套件时 运行的第一个测试通
  • Python textwrap.wrap 导致 \n 问题

    所以我只是重新格式化了一堆代码以合并textwrap wrap 却发现我所有的 n都消失了 这是一个例子 from textwrap import wrap def wrapAndPrint msg width 25 wrap msg to
  • Web 应用程序框架:C++ 与 Python

    作为一名程序员 我熟悉 Python 和 C 我正在考虑编写自己的简单 Web 应用程序 并且想知道哪种语言更适合服务器端 Web 开发 我正在寻找一些东西 它必须是直观的 我认识到 Wt 存在并且它遵循 Qt 的模型 我讨厌 Qt 的一件
  • Python 日志记录 - 如何检查记录器是否为空

    我刚刚在我的应用程序中实现了日志记录 我想知道是否有一种方法可以检查记录器是否为空 我的想法是在我的脚本中设置两个处理程序 一个用于带水平仪的控制台WARNING 一个用于带级别的文件DEBUG 在脚本的最后 我需要检查是否CONSOLE记
  • 即使使用标头和 Session 对象,Python requests.get 也会失败并出现 403 禁止

    我正在发出 GET 请求来获取 JSON 它在任何设备上的任何浏览器中都可以正常工作 但不能通过 python 请求 url https angel co autocomplete new tags params query sci tag
  • Kivy错误(python 2.7):sdl2导入错误

    我尝试在我的 Python 2 7 项目 在 PyCharm Windows 10 环境中 上使用 kivy 但出现以下错误 如果有人可以帮助我吗 谢谢 PS 我多次尝试卸载 重新安装库等 并按照像这样的帖子上的建议进行操作 但它不起作用
  • Python 多处理:全局对象未正确复制到子级

    前几天我回答了一个关于SO的问题 https stackoverflow com q 67047533 1925388关于并行读取 tar 文件 这是问题的要点 import bz2 import tarfile from multipro
  • pygame.image.load 不工作

    我正在尝试为游戏创建世界地图 但是当我尝试将世界地图加载到屏幕上时 命令行告诉我无法执行此操作 这是代码 import sys import pygame from pygame locals import pygame init Surf
  • 类型错误:不可散列的类型:pandas 的“切片”

    我有一个 pandas 数据结构 我这样创建 test inputs pd read csv input test csv delimiter 它的形状 print test inputs shape is this 28000 784 我
  • 按键合并的两个字典的值的并集

    我有两本词典 d1 a x y b k l d2 a m n c p r 如何合并这两个字典以获得这样的结果 d3 a x y m n b k l c p r 当字典的值是简单类型 如 int 或 str 时 这有效 d3 dict i a
  • Pandas 将时间序列数据重新采样为 15 分钟和 45 分钟 - 使用多索引或列

    我有一些时间序列数据作为 Pandas 数据框 它从每小时过去 15 分钟和过去 45 分钟 时间间隔为 30 分钟 的观察开始 然后将频率更改为每分钟 我想对数据进行重新采样 以便整个数据帧的频率为每 30 分钟一次 15 点和 45 点

随机推荐

  • 使 CSS 三角形垂直重复(锯齿图案)

    我有这样的导航 我想要 gt gt gt gt gt 我认为将其作为单独的 div 来完成是最简单的 第二个只关心沿着导航长度重复一个模式 我在这里寻求帮助 但我发现的大多数文章都是关于水平重复三角形的 我喜欢这个解决方案http jsfi
  • whoosh 是否要求所有字符串都是 unicode ?

    我正在 Solr 的 Whoosh 中重做我的搜索应用程序 我现在正在学习快速开始 但每次我不得不处理字符串时我都会遇到问题 gt gt gt writer add document iden fil content F2T file to
  • WSO2 ESB 4.9.0:错误 101500 意味着什么

    在连接到服务器时 我们会收到如下错误
  • VBA AddressOf 崩溃 Office 应用程序

    我想运行一个简单的代码片段 但每次 Access 和 Excel 都会崩溃 我正在运行回调测试 2 您能帮我一下吗 多谢 Declare Function CallWindowProc Lib user32 Alias CallWindow
  • 为什么 Int 不继承/扩展 Ordered[Int]

    我有一个关于字体设计的问题 为什么 Int 不扩展 Ordered 特征 Int 不是天生有序的吗 相反 scala 库提供了隐式 orderer 方法 将 Int 转换为 Ordered Int 这里做出了哪些设计选择 示例取自 Scal
  • 禁用单选按钮单击上的下拉框

    我有两个单选按钮和一个下拉框 如下所示 我想做的是 1 选中 否 时 隐藏或灰显下拉框 然后 2 当选中 是 时 显示下拉框 任何指示将不胜感激 td td
  • 当推送到heroku时,有没有办法将资产管道资产转移到s3?

    有没有一种简单的方法可以通过资产管道并部署到heroku s3 我希望我的本地 Rails 应用程序能够正常在本地使用 image css js 当您预编译时 生产应用程序是否有一种简单的方法可以从 s3 提供其资产 而开发资产是本地的 而
  • 快速除以 10ˣ

    In my program I use a lot of integer division by 10x and integer mod function of power 10 例如 unsigned int64 a 12345 a a
  • 结构(差异列表) Prolog

    这个问题参考了本书第三章的材料 Prolog Clocksin 和 Mellish 编程 Ed 5 本书第72页显示了一个使用差异列表的程序 partsOf X P partsacc X P Hole Hole partsacc X X H
  • 为什么 fputs 和 fprintf 反转流顺序

    我不明白为什么 fputs 和 fprintf 反转流顺序 int fputs const char str FILE stream int fprintf FILE stream const char format ssize t wri
  • Github:分叉队列与拉取请求

    我正在与朋友在 Github 上开始一个项目 到目前为止 他已经创建了存储库 我也已经分叉了它 我开始对我的存储库进行更改 提交并将更改推送到源 我的分叉副本 我们现在准备将这些更改集成到他的原始存储库中 Fork 队列和 Pull 请求有
  • 解析格式奇怪的日期时间。有人想上前吗?

    我正在尝试解析日期戳 我从 Twitter 获得 但收到错误 这是日期戳 2010 年 8 月 27 日星期五 22 00 07 0000 这是我的代码 DateTime ParseExact MyDateValue ddd dd MMM
  • Laravel:具有共享表的多对多

    I have Locations模型有很多Employees 相似地Employees属于Locations 这很好而且效果很好 但后来我考虑添加PhoneNumbers Either a Location or an Employee可以
  • Debezium-不含连接器类型

    我正在尝试使用 Debezium 连接到本地计算机上的 mysql 数据库 尝试使用以下命令调用kafka sudo kafka bin connect standalone shsh kafka config connect standa
  • 逐行读取文件

    我正在尝试逐行读取文件 但我不知道如何在到达 EOF 时停止 我有这个代码 readWholeFile do inputFile lt openFile example txt ReadMode readALine inputFile re
  • 无法解决原木锻造强化问题

    我在修复 Fortify 中的日志锻造问题时遇到问题 getLongFromTimestamp 方法中的两个日志记录调用都引发了 将未经验证的用户输入写入日志 的问题 public long getLongFromTimestamp fin
  • 使用多个 EJB 引用部署 Ear Web 应用程序时出现 Glassfish 错误

    继续部署 Ear Web 应用程序时 Glassfish 出错 https stackoverflow com questions 52400938 glassfish error while deploying ear web appli
  • 为什么 C# .NET SortedList 实际上没有 ElementAt?

    3 5 Collections Generic SortedList 的 NET 文档 http msdn microsoft com en us library ms132320 28v vs 90 29 aspx 在文档中 它明确指出
  • 使用 Oauth 在 Node.js 中将图像发布到 twitter

    我正在尝试使用 Oauth 模块将图像发布到 Twitter 这是我所拥有的 它抛出了 403 错误 我知道我在将媒体添加到帖子中的方式上做错了 但我只是不确定从这里该去哪里 var https require https var OAut
  • 如何在tensorflow中使用自定义python函数预取数据

    我正在尝试预取训练数据以隐藏 I O 延迟 我想编写自定义 Python 代码来从磁盘加载数据并预处理数据 例如通过添加上下文窗口 换句话说 一个线程进行数据预处理 另一个线程进行训练 这在 TensorFlow 中可能吗 更新 我有一个基