使用 TensorFlow Dataset API 和 flat_map 的并行线程

2024-01-30

我正在将 TensorFlow 代码从旧的队列接口更改为新的数据集API https://www.tensorflow.org/api_docs/python/tf/data/Dataset。使用旧界面我可以指定num_threads论证tf.train.shuffle_batch队列。然而,控制 Dataset API 中线程数量的唯一方法似乎是在map函数使用num_parallel_calls争论。但是,我正在使用flat_map函数代替,它没有这样的参数。

Question: 有没有办法控制线程/进程的数量flat_map功能?或者有什么办法可以使用map结合flat_map并仍然指定并行调用的数量?

请注意,并行运行多个线程至关重要,因为我打算在数据进入队列之前在 CPU 上运行大量预处理。

那里有两个 (here https://github.com/tensorflow/tensorflow/issues/7951#issuecomment-305796971 and here https://github.com/tensorflow/tensorflow/issues/7951#issuecomment-326098305)GitHub 上的相关帖子,但我认为他们没有回答这个问题。

这是我的用例的最小代码示例以供说明:

with tf.Graph().as_default():
    data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
    input_tensors = (data,)

    def pre_processing_func(data_):
        # normally I would do data-augmentation here
        results = (tf.expand_dims(data_, axis=0),)
        return tf.data.Dataset.from_tensor_slices(results)

    dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
    dataset = dataset_source.flat_map(pre_processing_func)
    # do something with 'dataset'

据我所知,目前flat_map不提供并行选项。 鉴于大部分计算是在pre_processing_func,您可以使用并行作为解决方法map调用后进行一些缓冲,然后使用flat_map使用负责平坦化输出的恒等 lambda 函数进行调用。

In code:

NUM_THREADS = 5
BUFFER_SIZE = 1000

def pre_processing_func(data_):
    # data-augmentation here
    # generate new samples starting from the sample `data_`
    artificial_samples = generate_from_sample(data_)
    return atificial_samples

dataset_source = (tf.data.Dataset.from_tensor_slices(input_tensors).
                  map(pre_processing_func, num_parallel_calls=NUM_THREADS).
                  prefetch(BUFFER_SIZE).
                  flat_map(lambda *x : tf.data.Dataset.from_tensor_slices(x)).
                  shuffle(BUFFER_SIZE)) # my addition, probably necessary though

注意(对我自己和任何试图理解管道的人):

Since pre_processing_func从初始样本开始生成任意数量的新样本(以形状矩阵组织)(?, 512)), the flat_map需要调用才能将所有生成的矩阵转换为Datasets 包含单个样本(因此tf.data.Dataset.from_tensor_slices(x)在 lambda 中),然后将所有这些数据集扁平化为一个大数据集Dataset包含单独的样本。

这可能是个好主意.shuffle()该数据集或生成的样本将打包在一起。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用 TensorFlow Dataset API 和 flat_map 的并行线程 的相关文章

  • django_openid_auth TypeError openid.yadis.manager.YadisServiceManager 对象不是 JSON 可序列化

    I used django openid auth在我的项目上 一段时间以来它运行得很好 但今天 我测试了该应用程序并遇到了这个异常 Environment Request Method GET Request URL http local
  • 如何在 AWS CDK 创建的 Python Lambda 函数中安装外部模块?

    我在 Cloud9 中使用 Python AWS CDK 并且我部署简单的 Lambda 函数那应该是发送 API 请求到 Atlassian 的 API当对象上传到 S3 存储桶时 也是由 CDK 创建的 这是我的 CDK 堆栈代码 fr
  • python future 和元组解包

    实现像使用 future 进行元组解包这样的事情的优雅 惯用的方法是什么 我有这样的代码 a b c f x y g a b z h y c 我想将其转换为使用期货 理想情况下我想写一些类似的东西 a b c ex submit f x y
  • Python模块可以访问英语词典,包括单词的定义[关闭]

    Closed 这个问题不符合堆栈溢出指南 help closed questions 目前不接受答案 我正在寻找一个 python 模块 它可以帮助我从英语词典中获取单词的定义 当然有enchant 这可以帮助我检查该单词是否存在于英语中
  • Python逻辑运算符优先级[重复]

    这个问题在这里已经有答案了 哪个运算符优先4 gt 5 or 3 lt 4 and 9 gt 8 这会被评估为真还是假 我知道该声明3 gt 4 or 2 lt 3 and 9 gt 10 显然应该评估为 false 但我不太确定 pyth
  • 如何使用 imaplib 获取“消息 ID”

    我尝试获取一个在操作期间不会更改的唯一 ID 我觉得UID不好 所以我认为 Message ID 是正确的 但我不知道如何获取它 我只知道 imap fetch uid XXXX 有人有解决方案吗 来自 IMAP 文档本身 IMAP4消息号
  • Pandas 中允许重复列

    我将一个大的 CSV 包含股票财务数据 文件分割成更小的块 CSV 文件的格式不同 像 Excel 数据透视表之类的东西 第一列的前几行包含一些标题 公司名称 ID 等在以下列中重复 因为一家公司有多个属性 而不是一家公司只有一栏 在前几行
  • 如何计算numpy数组中元素的频率?

    我有一个 3 D numpy 数组 其中包含重复的元素 counterTraj shape 13530 1 1 例如 counterTraj 包含这样的元素 我只显示了几个元素 array 136 129 130 103 102 101 我
  • 为什么Python的curses中escape键有延迟?

    In the Python curses module I have observed that there is a roughly 1 second delay between pressing the esc key and getc
  • 以同步方式使用 FastAPI,如何获取 POST 请求的原始正文?

    在中使用 FastAPIsync not async模式 我希望能够接收 POST 请求的原始 未更改的正文 我能找到的所有例子都显示async代码 当我以正常同步方式尝试时 request body 显示为协程对象 当我通过发布一些内容来
  • Python urllib.request.urlopen:AttributeError:'bytes'对象没有属性'data'

    我正在使用 Python 3 并尝试连接到dstk 我收到错误urllib包裹 我对SO进行了很多研究 但找不到与这个问题类似的东西 api url self api base street2coordinates api body jso
  • 从 python 发起 SSH 隧道时出现问题

    目标是在卫星服务器和集中式注册数据库之间建立 n 个 ssh 隧道 我已经在我的服务器之间设置了公钥身份验证 因此它们只需直接登录而无需密码提示 怎么办 我试过帕拉米科 它看起来不错 但仅仅建立一个基本的隧道就变得相当复杂 尽管代码示例将受
  • 使用鼻子获取设置中当前测试的名称

    我目前正在使用鼻子编写一些功能测试 我正在测试的库操作目录结构 为了获得可重现的结果 我存储了一个测试目录结构的模板 并在执行测试之前创建该模板的副本 我在测试中执行此操作 setup功能 这确保了我在测试开始时始终具有明确定义的状态 现在
  • Seaborn Pairplot 图例不显示颜色

    我一直在学习如何在Python中使用seaborn和pairplot 这里的一切似乎都工作正常 但由于某种原因 图例不会显示相关的颜色 我无法找到解决方案 因此如果有人有任何建议 请告诉我 x sns pairplot stats2 hue
  • 无法在 osx-arm64 上安装 Python 3.7

    我正在尝试使用 Conda 创建一个带有 Python 3 7 的新环境 例如 conda create n qnn python 3 7 我收到以下错误 Collecting package metadata current repoda
  • 当鼠标悬停在上面时,intellisense vscode 不显示参数或文档

    我正在尝试将整个工作流程从 Eclipse 和 Jupyter Notebook 迁移到 VS Code 我安装了 python 扩展 它应该带有 Intellisense 但它只是部分更糟糕 我在输入句点后收到建议 但当将鼠标悬停在其上方
  • 在Python中按属性获取对象列表中的索引

    我有具有属性 id 的对象列表 我想找到具有特定 id 的对象的索引 我写了这样的东西 index 1 for i in range len my list if my list i id specific id index i break
  • 具有自定义值的 Django 管理外键下拉列表

    我有 3 个 Django 模型 class Test models Model pass class Page models Model test models ForeignKey Test class Question model M
  • 如何读取Python字节码?

    我很难理解 Python 的字节码及其dis module import dis def func x 1 dis dis func 上述代码在解释器中输入时会产生以下输出 0 LOAD CONST 1 1 3 STORE FAST 0 x
  • 您可以使用关键字参数而不提供默认值吗?

    我习惯于在 Python 中使用这样的函数 方法定义 def my function arg1 None arg2 default do stuff here 如果我不供应arg1 or arg2 那么默认值None or default

随机推荐

  • NGRX - 如何计算商店中商品的属性

    我们在 Angular 应用程序中使用 NGRX 数据来自 API 某些属性以未格式化的字符串形式来自 API 因此我们需要对其进行格式化 当然 这可以在 HTML 中完成 但问题是在 HTML 和 TypeScript 中的多个位置都需要
  • ZF2 - 使用导航视图助手的多个导航菜单

    我正在尝试将主导航与子菜单结合使用以进行更具体的导航 In my layout我这样调用视图助手 this gt navigation main navigation gt menu 并在我的view我这样称呼它 this gt navig
  • 访问 Spark RDD 时闭包中局部变量的使用

    我有一个关于访问 Spark RDD 时闭包中局部变量的使用的问题 我想解决的问题如下 我有一个应该读入 RDD 的文本文件列表 但是 首先我需要向从单个文本文件创建的 RDD 添加附加信息 此附加信息是从文件名中提取的 然后 使用 uni
  • 将 ShapeRenderer.begin/end 嵌套在 SpriteBatch.begin/end 中

    是否可以使用绘制形状ShapeRenderer之间SpriteBatch begin and end calls 我已经尝试过 但没有结果 只绘制了 SpriteBatch 纹理 场景中没有形状 示例代码如下 shapeRenderer b
  • Objective-C 接口 - 声明变量与仅声明属性?

    在 Obj c 中 在 interface 中声明变量时 接口 NSObject 我的对象 我的对象 property 不安全 非原子 MyObject myObject 对比 仅将其声明为属性 接口 NSObject property 不
  • 通过 HTTP 回调函数进行 Google 地理编码?

    我想使用谷歌地理编码via HTTP在我的 AJAX Web 应用程序中将城市名称转换为经度和纬度的功能 但是 HTTP 地理编码器功能似乎不存在回调函数 http code google com apis maps documentati
  • 无法将操作提供者强制转换为共享操作提供者

    下面是我的活动的代码 import android app Activity import android os Bundle import android support v7 widget ShareActionProvider imp
  • Clojure 协议与 Scala 结构类型

    看完后里奇 希基的采访 http www infoq com interviews hickey clojure protocols on 协议 http clojure org protocols在 Clojure 1 2 中 对 Clo
  • 线程无异常地死亡

    我的一些工作线程遇到问题 我在线程的 run 方法中添加了一个包罗万象的异常语句 如下所示 try Runs the worker process which is a state machine while self set exitco
  • 为什么 ListView 项目不会增长以包裹其内容?

    我有一个相当复杂的 ListView 具有可变的列表项高度 在某些情况下 我需要在列表项中显示一个附加视图 该视图默认是隐藏的 View GONE 通过启用它 View VISIBLE 列表项的高度会增加 或者至少应该如此 问题 即使我将项
  • 如何将文件夹结构复制到另一个目录下?

    我有一些与复制文件夹结构相关的问题 事实上 我需要将pdf文件转换为文本文件 因此 我导入 pdf 的位置有这样的文件夹结构 D f subfolder1 subfolder2 a pdf 我想在 下创建确切的文件夹结构D g subfol
  • 在 Hibernate 中运行时急切加载整个对象图

    在说出 指定查询中的获取类型 之类的内容之前 请先阅读以下内容 那不是我所追求的 我正在寻找一种方法来急切加载完整的对象图 对象 它的所有子对象以及它们的所有子对象等等 I do not想要枚举要加载的所有属性 直到运行时我才认识它们 N
  • Linux中如何保护进程间共享的内存

    在 Linux 或其他现代操作系统中 每个进程的内存都受到保护 因此一个进程中的疯狂写入不会导致任何其他进程崩溃 现在假设我们在进程 A 和进程 B 之间共享内存 现在假设 由于软错误 进程 A 无意中向该内存区域写入了一些内容 鉴于进程
  • 如何在 React 中呈现未定义状态的数据?

    我正在 componentDidMount 内获取数据 但在初始渲染期间我未定义 然后再次渲染发生 并且在此期间状态变量被填充 现在 当它不是未定义的并且在填充之后 我想对其进行解构并在我的组件内显示数据 注意 getProjectDeta
  • C++ 中的变量作用域?

    在 C 中 main 中声明的任何变量都可以在整个 main 中使用 对吗 我的意思是 如果变量是在 try 循环中声明的 它们仍然可以在整个 main 中访问吗 因为我在 main 的 try 循环中声明了几个变量 但是如果我在 try
  • 如何获取字典中键的ReadOnlyCollection

    我的课程包含一个Dictionary
  • python time.strftime %z 始终为零而不是时区偏移

    gt gt gt import time gt gt gt t 1440935442 gt gt gt time strftime Y m d H M S z time gmtime t 2015 08 30 11 50 42 0000 g
  • 按实体名称和上次修改日期搜索

    我在 RavenDb 中存储了许多命令 它们都实现了 ICommand 我希望能够搜索上次修改的元数据和 Raven Entity Name 我目前正在对每个命令进行多重映射 如下所示 public class CommandAuditSe
  • 将 PHP 日期范围转换为 MYSQL 单个日期

    我有一个可用性日历 其中我当前正在逐个添加日期 并使用 mysql 查询来确定是否存在具有特定日期的行 并将当天的类别更改为 已预订 红色 我想在我的表单中输入一个范围 并通过 php 或 mysql 将其处理为多个单独的日期 我的日期格式
  • 使用 TensorFlow Dataset API 和 flat_map 的并行线程

    我正在将 TensorFlow 代码从旧的队列接口更改为新的数据集API https www tensorflow org api docs python tf data Dataset 使用旧界面我可以指定num threads论证tf