Tensorflow tf.data.Dataset.cache似乎没有达到预期的效果

2024-03-13

我正在尝试按照以下方法提高我的模型训练性能使用 tf.data API 获得更好的性能 https://www.tensorflow.org/guide/data_performance指导方针。然而,我观察到使用的性能.cache()如果与没有的相同设置相比,几乎相同甚至更糟.cache().

datafile_list = load_my_files()
RAW_BYTES = 403*4
BATCH_SIZE = 32

raw_dataset = tf.data.FixedLengthRecordDataset(filenames=datafile_list, record_bytes=RAW_BYTES, num_parallel_reads=10, buffer_size=1024*RAW_BYTES)
raw_dataset = raw_dataset.map(tf.autograph.experimental.do_not_convert(decode_and_prepare),
    num_parallel_calls=tf.data.AUTOTUNE)
raw_dataset = raw_dataset.cache()
raw_dataset = raw_dataset.shuffle(buffer_size=1024)
raw_dataset = raw_dataset.batch(BATCH_SIZE)
raw_dataset = raw_dataset.prefetch(tf.data.AUTOTUNE)

数据在datafile_list保留 9.92GB,相当适合系统可用的总物理 RAM (100GB)。系统交换已禁用。

通过使用数据集训练模型:

model = build_model()
model.fit(raw_dataset, epochs=5, verbose=2)

结果是:

Epoch 1/5
206247/206247 - 126s - loss: 0.0043 - mae: 0.0494 - mse: 0.0043
Epoch 2/5
206247/206247 - 125s - loss: 0.0029 - mae: 0.0415 - mse: 0.0029
Epoch 3/5
206247/206247 - 129s - loss: 0.0027 - mae: 0.0397 - mse: 0.0027
Epoch 4/5
206247/206247 - 125s - loss: 0.0025 - mae: 0.0386 - mse: 0.0025
Epoch 5/5
206247/206247 - 125s - loss: 0.0024 - mae: 0.0379 - mse: 0.0024

这个结果令人沮丧。由docs https://www.tensorflow.org/api_docs/python/tf/data/Dataset#cache:

第一次迭代数据集时,其元素将被缓存在指定文件或内存中。后续迭代将使用缓存的数据。

并从本指南 https://www.tensorflow.org/datasets/performances#caching_the_dataset:

迭代此数据集时,由于缓存,第二次迭代将比第一次迭代快得多。

然而,所有历元所花费的时间几乎相同。此外,在训练过程中,CPU 和 GPU 的使用率都非常低(见下图)。

通过注释掉该行raw_dataset = raw_dataset.cache()结果没有显示任何显着差异:

Epoch 1/5
206067/206067 - 129s - loss: 0.0042 - mae: 0.0492 - mse: 0.0042
Epoch 2/5
206067/206067 - 127s - loss: 0.0028 - mae: 0.0412 - mse: 0.0028
Epoch 3/5
206067/206067 - 134s - loss: 0.0026 - mae: 0.0393 - mse: 0.0026
Epoch 4/5
206067/206067 - 127s - loss: 0.0024 - mae: 0.0383 - mse: 0.0024
Epoch 5/5
206067/206067 - 126s - loss: 0.0023 - mae: 0.0376 - mse: 0.0023

正如文档中指出的,我的期望是使用缓存会导致训练时间更快。我想知道我做错了什么。

附件

使用缓存进行训练期间的 GPU 使用情况:

训练期间没有缓存的 GPU 使用情况:

使用缓存进行训练期间的系统统计信息(内存、CPU 等):

训练期间没有缓存的系统统计信息(内存、CPU 等):


只是使用 Google Colab 进行的一个小观察。根据docs https://www.tensorflow.org/api_docs/python/tf/data/Dataset?version=nightly#cache:

注意:为了最终确定缓存,必须完整迭代输入数据集。否则,后续迭代将不会使用缓存数据。

And

注意:缓存每次都会产生完全相同的元素 迭代数据集。如果您希望随机化迭代 order,确保在调用cache之后调用shuffle。

我确实注意到在事先使用缓存和迭代数据集时存在一些差异。这是一个例子。

准备数据:

import random
import struct
import tensorflow as tf
import numpy as np

RAW_N = 2 + 20*20 + 1

bytess = random.sample(range(1, 5000), RAW_N*4)
with open('mydata.bin', 'wb') as f:
  f.write(struct.pack('1612i', *bytess))
def decode_and_prepare(register):
  register = tf.io.decode_raw(register, out_type=tf.float32)
  inputs = register[2:402]
  label = tf.random.uniform(()) + register[402:]
  return inputs, label

raw_dataset = tf.data.FixedLengthRecordDataset(filenames=['/content/mydata.bin']*7000, record_bytes=RAW_N*4)
raw_dataset = raw_dataset.map(decode_and_prepare)

火车模型without预先缓存和迭代:

total_data_entries = len(list(raw_dataset.map(lambda x, y: (x, y))))
train_ds = raw_dataset.shuffle(buffer_size=total_data_entries).batch(32).prefetch(tf.data.AUTOTUNE)
inputs = tf.keras.layers.Input((400,))
x = tf.keras.layers.Dense(200, activation='relu', kernel_initializer='normal')(inputs)
x = tf.keras.layers.Dense(100, activation='relu', kernel_initializer='normal')(x)
outputs = tf.keras.layers.Dense(1, kernel_initializer='normal')(x)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='mse')
model.fit(train_ds, epochs=5)
Epoch 1/5
875/875 [==============================] - 4s 3ms/step - loss: 0.1425
Epoch 2/5
875/875 [==============================] - 4s 3ms/step - loss: 0.0841
Epoch 3/5
875/875 [==============================] - 4s 3ms/step - loss: 0.0840
Epoch 4/5
875/875 [==============================] - 4s 3ms/step - loss: 0.0840
Epoch 5/5
875/875 [==============================] - 4s 3ms/step - loss: 0.0840
<keras.callbacks.History at 0x7fc41be037d0>

训练模型with缓存但是no迭代:

total_data_entries = len(list(raw_dataset.map(lambda x, y: (x, y))))
train_ds = raw_dataset.shuffle(buffer_size=total_data_entries).cache().batch(32).prefetch(tf.data.AUTOTUNE)
inputs = tf.keras.layers.Input((400,))
x = tf.keras.layers.Dense(200, activation='relu', kernel_initializer='normal')(inputs)
x = tf.keras.layers.Dense(100, activation='relu', kernel_initializer='normal')(x)
outputs = tf.keras.layers.Dense(1, kernel_initializer='normal')(x)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='mse')
model.fit(train_ds, epochs=5)
Epoch 1/5
875/875 [==============================] - 4s 2ms/step - loss: 0.1428
Epoch 2/5
875/875 [==============================] - 2s 2ms/step - loss: 0.0841
Epoch 3/5
875/875 [==============================] - 2s 2ms/step - loss: 0.0840
Epoch 4/5
875/875 [==============================] - 2s 2ms/step - loss: 0.0840
Epoch 5/5
875/875 [==============================] - 2s 3ms/step - loss: 0.0840
<keras.callbacks.History at 0x7fc41fa87810>

训练模型with缓存和迭代:

total_data_entries = len(list(raw_dataset.map(lambda x, y: (x, y))))
train_ds = raw_dataset.shuffle(buffer_size=total_data_entries).cache().batch(32).prefetch(tf.data.AUTOTUNE)
_ = list(train_ds.as_numpy_iterator()) # iterate dataset beforehand
inputs = tf.keras.layers.Input((400,))
x = tf.keras.layers.Dense(200, activation='relu', kernel_initializer='normal')(inputs)
x = tf.keras.layers.Dense(100, activation='relu', kernel_initializer='normal')(x)
outputs = tf.keras.layers.Dense(1, kernel_initializer='normal')(x)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='mse')
model.fit(train_ds, epochs=5)
Epoch 1/5
875/875 [==============================] - 3s 3ms/step - loss: 0.1427
Epoch 2/5
875/875 [==============================] - 2s 2ms/step - loss: 0.0841
Epoch 3/5
875/875 [==============================] - 2s 2ms/step - loss: 0.0840
Epoch 4/5
875/875 [==============================] - 2s 2ms/step - loss: 0.0840
Epoch 5/5
875/875 [==============================] - 2s 2ms/step - loss: 0.0840
<keras.callbacks.History at 0x7fc41ac9c850>

结论:数据集的缓存和先前迭代似乎对训练有影响,但在本例中仅使用了 7000 个文件。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow tf.data.Dataset.cache似乎没有达到预期的效果 的相关文章

  • 保存为 HDF5 的图像未着色

    我目前正在开发一个将文本文件和 jpg 图像转换为 HDF5 格式的程序 用HDFView 3 0打开 似乎图像仅以灰度保存 hdf h5py File Sample h5 img Image open Image jpg data np
  • 使用特定的类/函数预加载 Jupyter Notebook

    我想预加载一个笔记本 其中包含我在另一个文件中定义的特定类 函数 更具体地说 我想用 python 来做到这一点 比如加载一个配置文件 包含所有相关的类 函数 目前 我正在使用 python 生成笔记本并在服务器上自动启动它们 因为不同的
  • 如何用python脚本控制TP LINK路由器

    我想知道是否有一个工具可以让我连接到路由器并关闭它 然后从 python 脚本重新启动它 我知道如果我写 import os os system ssh l root 192 168 2 1 我可以通过 python 连接到我的路由器 但是
  • 如何使用 opencv.omnidir 模块对鱼眼图像进行去扭曲

    我正在尝试使用全向模块 http docs opencv org trunk db dd2 namespacecv 1 1omnidir html用于对鱼眼图像进行扭曲处理Python 我正在尝试适应这一点C 教程 http docs op
  • 安装了 32 位的 Python,显示为 64 位

    我需要运行 32 位版本的 Python 我认为这就是我在我的机器上运行的 因为这是我下载的安装程序 当我重新运行安装程序时 它会将当前安装的 Python 版本称为 Python 3 5 32 位 然而当我跑步时platform arch
  • 需要在python中找到print或printf的源代码[关闭]

    很难说出这里问的是什么 这个问题是含糊的 模糊的 不完整的 过于宽泛的或修辞性的 无法以目前的形式得到合理的回答 如需帮助澄清此问题以便重新打开 访问帮助中心 help reopen questions 我正在做一些我不能完全谈论的事情 我
  • 删除flask中的一对一关系

    我目前正在使用 Flask 开发一个应用程序 并且在删除一对一关系中的项目时遇到了一个大问题 我的模型中有以下结构 class User db Model tablename user user id db Column db String
  • 使用Python请求登录Google帐户

    在多个登录页面上 需要谷歌登录才能继续 我想用requestspython 中的库以便让我自己登录 通常这很容易使用requests库 但是我无法让它工作 我不确定这是否是由于 Google 做出的一些限制 也许我需要使用他们的 API 或
  • 张量流服务错误:参数无效:JSON 对象:没有命名输入

    我正在尝试使用 Amazon Sagemaker 训练模型 并且希望使用 Tensorflow 服务来为其提供服务 为了实现这一目标 我将模型下载到 Tensorflow 服务 docker 并尝试从那里提供服务 Sagemaker 的训练
  • YOLOv8获取预测边界框

    我想将 OpenCV 与 YOLOv8 集成ultralytics 所以我想从模型预测中获取边界框坐标 我该怎么做呢 from ultralytics import YOLO import cv2 model YOLO yolov8n pt
  • 使用 xlrd 打开 BytesIO (xlsx)

    我正在使用 Django 需要读取上传的 xlsx 文件的工作表和单元格 使用 xlrd 应该可以 但因为文件必须保留在内存中并且可能不会保存到我不知道如何继续的位置 本例中的起点是一个带有上传输入和提交按钮的网页 提交后 文件被捕获req
  • 从Python中的字典列表中查找特定值

    我的字典列表中有以下数据 data I versicolor 0 Sepal Length 7 9 I setosa 0 I virginica 1 I versicolor 0 I setosa 1 I virginica 0 Sepal
  • 在Python中检索PostgreSQL数据库的新记录

    在数据库表中 第二列和第三列有数字 将会不断添加新行 每次 每当数据库表中添加新行时 python 都需要不断检查它们 当 sql 表中收到的新行数低于 105 时 python 应打印一条通知消息 警告 数量已降至 105 以下 另一方面
  • javascript 是否有等效的 __repr__ ?

    我最接近Python的东西repr这是 function User name password this name name this password password User prototype toString function r
  • Jupyter Notebook 找不到 Python 模块

    不知道发生了什么 但每当我使用 ipython 氢 原子 或 jupyter 笔记本时都找不到任何已安装的模块 我知道我安装了 pandas 但笔记本说找不到 我应该补充一点 当我正常运行脚本时 python script py 它确实导入
  • 不同编程语言中的浮点数学

    我知道浮点数学充其量可能是丑陋的 但我想知道是否有人可以解释以下怪癖 在大多数编程语言中 我测试了 0 4 到 0 2 的加法会产生轻微的错误 而 0 4 0 1 0 1 则不会产生错误 两者计算不平等的原因是什么 在各自的编程语言中可以采
  • 如何在 Windows 命令行中使用参数运行 Python 脚本

    这是我的蟒蛇hello py script def hello a b print hello and that s your sum sum a b print sum import sys if name main hello sys
  • Django-tables2 列总计

    我正在尝试使用此总结列中的所有值文档 https github com bradleyayers django tables2 blob master docs pages column headers and footers rst 但页
  • 如何应用一个函数 n 次? [关闭]

    Closed 这个问题需要细节或清晰度 help closed questions 目前不接受答案 假设我有一个函数 它接受一个参数并返回相同类型的结果 def increment x return x 1 如何制作高阶函数repeat可以
  • Pandas 每周计算重复值

    我有一个Dataframe包含按周分组的日期和 ID df date id 2022 02 07 1 3 5 4 2022 02 14 2 1 3 2022 02 21 9 10 1 2022 05 16 我想计算每周有多少 id 与上周重

随机推荐

  • 如何将具有前端 SPA 的 Azure CDN 和具有 .Net Core WebApi 的 Azure WebApp 配置到同一自定义域?

    我想拥有https example com https example com作为我设置的 Azure CDN 的自定义域 并且https example com api https example com api作为其余 api 端点来捕
  • 组对组划分

    数据集 date bal 1 31 2013 10 1 31 2013 11 1 31 2013 12 1 31 2013 13 1 31 2013 14 2 28 2013 20 2 28 2013 30 2 28 2013 40 2 2
  • 异步 P/Invoke 调用

    我正在为机器人控制器开发一个包装库 该库主要依赖于 P Invoke 调用 然而 机器人的许多功能 例如归位或移动 需要相当长的时间 并且在运行时会进行线程锁定 所以我想知道如何以异步方式包装功能 这样调用就不会阻塞我的 UI 线程 到目前
  • 如何链接到 rustdoc 中的其他 fns/structs/enums/traits?

    我正在构建一个 Rust 库 并想对其进行一些改进 在 rustdoc 中 我有时想link文档中库的其他部分 例如fns traits or structs 官方语法是什么 As of 铁锈 1 48 https github com r
  • Django 反序列化错误安装 Fixture 时出现问题

    Traceback most recent call last File Users sparshkedia Desktop task venv lib python3 6 site packages django core seriali
  • 如何对这个哈希数组进行分组?

    我有这个哈希数组 name Ben age 18 name David age 19 name Sam age 18 我需要将它们分组age 所以他们最终会变成这样 18 name Ben age 18 name Sam age 18 19
  • NestJs中带有多个参数的@Get DTO

    我正在尝试在 NestJS 中创建一个可通过 GET HTTP 请求访问的控制器操作 该请求接收两个参数 但由于某种原因它们未定义 如何修复它 Get login login Param params LoginUserDto consol
  • 在 Tumblr 上每 3 个帖子添加内容

    我想知道是否有办法在每个页面上的第 3 篇文章之后放置内容 以便我可以渲染一些内容 我在 tumblr 主题 API 上没有找到任何内容 带有 API 的特定帖子 如果您使用 API 来收集 附加帖子 则需要您完成此操作 一个简单的循环 计
  • “我们很抱歉,但有些不对劲。”部署到 Heroku 后

    我制作了一个小型应用程序 用户可以在其中登录 退出 创建等等 我使用 mySQL 作为数据库 并且在本地环境中一切正常 但是当我将其部署到heroku并迁移数据库等之后 heroku版本不起作用 当我追踪日志时我得到了这个 2011 10
  • 仅对单个类禁用 Linq to SQL 类中的自动复数化

    我有一个带有不规则复数的表名 复数与单数相同 有没有办法禁用该单个表的自动复数 Account DB Accounts 同时保留其他表的功能 您需要禁用 LINQ to SQL 设计器的复数表名称 为此 请导航至 工具 gt 选项 gt 数
  • 使用本地 WSDL 文件生成 Metro 客户端

    我之前使用 wsimport 生成了 Metro 客户端 但在这种情况下 WSDL 是通过 https 访问的 我的命令看起来像这样 wsimport https service net services Service wsdl d C
  • Ubuntu:按 Super+L 时不要锁定屏幕 [关闭]

    Closed 这个问题不符合堆栈溢出指南 help closed questions 目前不接受答案 Whenever I press Super L or Win L on my Ubuntu 14 04 Desktop the scre
  • 按值字母顺序对 Javascript 对象进行排序

    我有一个 JS 对象如下 var obj 00 11 22 33 44 55 AddressB 66 77 88 99 AA BB AddressA 55 44 33 22 11 00 AddressC AA BB CC DD EE FF
  • Apache Kafka 主题名称限制有哪些?

    我刚刚尝试创建一个 Kafka 主题 user created 并在 Kafka 日志中看到此错误 Invalid character in value part of property 我用谷歌搜索发现 在邮件列表中 人们正在谈论弃用 a
  • React Native 后台计时器永远不会停止

    我正在构建一个应用程序 它有一个计时器 可以在计时器处于活动状态时请求地理位置 对于我正在使用的计时器反应本机背景计时器 https github com ocetnik react native background timer 这是可行
  • 调用 sp_rename 时使用变量

    我尝试制作一个存储过程 它将 删除主键 重命名设置主键的列名 创建新的主键 我正在努力解决第 2 点 我正在尝试将列重命名为sp rename将参数传递给存储过程 如下所示 EXEC sp rename SCHEMA TABLE ID Id
  • 为什么我运行 python manage.py runserver 时有两个进程

    wenzhixue 80384 0 4 1 1 2464788 22584 s001 S 10 37AM 0 01 06 usr bin python manage py runserver 0 0 0 0 8000 wenzhixue 8
  • 如何处理大量浮点数据?

    我们有一个二进制文件 其中包含大量float数据 约80MB 我们需要在 Java 应用程序中处理它 数据来自医疗扫描仪 一个文件包含来自一个文件的数据Rotation One Rotation包含 960Views One View包含
  • 为构建器配置 lombok

    我想避免多个构造函数 所以我想使用建造者设计模式 https en wikipedia org wiki Builder pattern 通过使用lombok https projectlombok org setup maven图书馆 它
  • Tensorflow tf.data.Dataset.cache似乎没有达到预期的效果

    我正在尝试按照以下方法提高我的模型训练性能使用 tf data API 获得更好的性能 https www tensorflow org guide data performance指导方针 然而 我观察到使用的性能 cache 如果与没有