具有非常大的 HDF5 文件的 Tensorflow-IO 数据集输入管道

2023-12-23

我有非常大的训练(30Gb)文件。
由于我的可用 RAM 无法容纳所有数据,因此我想批量读取数据。
我看到有 Tensorflow-io 包实施了一种方式 https://www.tensorflow.org/io/api_docs/python/tfio/IODataset#from_hdf5借助该函数,可以通过这种方式将 HDF5 读入 Tensorflowtfio.IODataset.from_hdf5()
然后,自从tf.keras.model.fit()需要一个tf.data.Dataset作为包含样本和目标的输入,我需要将 X 和 Y 压缩在一起,然后使用.batch and .prefetch仅将必要的数据加载到内存中。为了进行测试,我尝试将此方法应用于较小的样本:训练(9Gb)、验证(2.5Gb)和测试(1.2Gb),我知道它们效果很好,因为它们可以放入内存中,并且我得到了很好的结果(70%的准确度和训练文件存储在 HDF5 文件中,分为样本 (X) 和标签 (Y) 文件,如下所示:

X_learn.hdf5  
X_val.hdf5  
X_test.hdf5  
Y_test.hdf5  
Y_learn.hdf5  
Y_val.hdf5

这是我的代码:

BATCH_SIZE = 2048
EPOCHS = 100

# Create an IODataset from a hdf5 file's dataset object  
x_val = tfio.IODataset.from_hdf5(path_hdf5_x_val, dataset='/X_val')
y_val = tfio.IODataset.from_hdf5(path_hdf5_y_val, dataset='/Y_val')
x_test = tfio.IODataset.from_hdf5(path_hdf5_x_test, dataset='/X_test')
y_test = tfio.IODataset.from_hdf5(path_hdf5_y_test, dataset='/Y_test')
x_train = tfio.IODataset.from_hdf5(path_hdf5_x_train, dataset='/X_learn')
y_train = tfio.IODataset.from_hdf5(path_hdf5_y_train, dataset='/Y_learn')
 
# Zip together samples and corresponding labels
train = tf.data.Dataset.zip((x_train,y_train)).batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
test = tf.data.Dataset.zip((x_test,y_test)).batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
val = tf.data.Dataset.zip((x_val,y_val)).batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)

# Build the model
model = build_model()
 
# Compile the model with custom learing rate function for Adam optimizer
model.compile(loss='categorical_crossentropy',
               optimizer=Adam(lr=lr_schedule(0)),
               metrics=['accuracy'])

# Fit model with class_weights calculated before
model.fit(train,
          epochs=EPOCHS,
          class_weight=class_weights_train,
          validation_data=val,
          shuffle=True,
          callbacks=callbacks)

这段代码可以运行,但损失非常高(300+),并且准确度从一开始就下降到 0(0.30 -> 4*e^-5)...我不明白我做错了什么,我错过了吗某物 ?


在这里提供解决方案(答案部分),即使它出现在评论部分中也是为了社区的利益。

代码没有问题,它实际上与数据有关(未正确预处理),因此模型无法很好地学习,这会导致奇怪的损失和准确性。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

具有非常大的 HDF5 文件的 Tensorflow-IO 数据集输入管道 的相关文章

随机推荐

  • 如何绑定到 SwiftUI 中可选的数据

    我很好奇 我们如何指定对属于可选部分的状态数据的绑定 例如 struct NameRecord var name var isFunny false class AppData ObservableObject Published var
  • 使用 SSZipArchive 解压缩文件 - Swift

    我正在尝试使用 SSZipArchive 框架解压缩文件 let unzipper SSZipArchive unzipFileAtPath String document toDestination String documentsUrl
  • PyQt 打印原始 PDF

    假设我有一个test pdf文件在当前目录中 我想使用以下命令将此原始文件发送到打印机PyQt 图形用户界面打印机 下面的Python3代码打印PDF源代码 我不希望 Qt 为我构建 PDF 而只是使用 gui 对话框将其发送到打印机 这应
  • NVidia CUDA 工具包 7.5.27 无法在 OS X 上安装

    下载 CUDA 工具包 DMG 可以工作 但安装程序在选择软件包后失败 并出现神秘的 软件包清单解析错误 错误 使用内部二进制文件从命令行运行安装程序也会以类似的方式失败 var log cuda installer log 处的日志文件基
  • JavaScript 数组问题

    好的 我只是回顾一下 JavaScript 中的一些基本编程原则 我是编程新手 所以请耐心等待 下面是我遇到问题的代码 特别注意数组的字符串组件 var name new Array var sales new Array var tota
  • 如何突出显示列中非空白的重复项?

    我想突出显示 I 列中连接字符串的所有重复项 并在突出显示任何重复项时提供错误消息 但是 该列中有几个空白单元格 我不希望在运行宏时这些单元格显示为重复项 我从这里得到了这个代码 Sub HighlightDuplicateValues D
  • 没有编译器优化的 SSE 内在函数

    我是 SSE 内在函数的新手 并尝试通过它来优化我的代码 这是我的程序 用于计算等于给定值的数组元素 我将代码更改为 SSE 版本 但速度几乎没有改变 我想知道我是否以错误的方式使用SSE 此代码用于不允许我们启用编译器优化选项的分配 无
  • 当从 C# 程序中反序列化 JSON 时,我是否需要使用 JavaScriptSerializer 以外的任何东西?

    NET 中提供了 JavaScriptSerializer 类 System Web Script Serialization 命名空间 在 System Web Extensions dll 中提供 它最初旨在支持 AJAX Web 服务
  • 如何使用通配符设置docker的NO_PROXY

    正如 docker 官方文档中提到的here https docs docker com config daemon systemd configure where the docker daemon listens for connect
  • flatMap API 合约如何将可选输入转换为非可选结果?

    这是 Swift 3 0 2 中 flatMap 的合约 public struct Array
  • 从 Unity 中的 Android Studio 读取意图

    我有一个 Unity 游戏导出到 Android Studio 中 我有一个已保存游戏的列表 其中存储了玩家玩的每个游戏的最后一个场景 基本上存储玩家的进度 从 Unity 到 Android Studio 播放的最后一个场景的编写效果非常
  • Delphi 应用程序的插件系统 - bpl 与 dll?

    我正在编写delphi应用程序 它应该具有加载插件的能力 我使用 JvPluginManager 作为插件系统 管理器 现在 在新的插件向导中 他们说最好使用 bpl 类型插件而不是 dll 插件 这个解决方案与 dll 类型插件相比有什么
  • 增量求解有什么好处?

    如果 pop 完全破坏了上下文 即学到的引理 增量约束求解使用 堆栈 的目的是什么 模式 理由 我想如果我只有 1 个约束 几个 合词 最好进行单个查询 而不是 将单独帧中的合取词堆叠到堆栈上 如果我 有超过 1 个约束并决定使用增量求解
  • 如何使用 Gekko 加快优化速度?

    我的计划是优化家用电池的充电和放电 以最大限度地降低年底的电力成本 每15分钟测量一次家庭用电量 所以我在1天内有96个测量点 我想优化电池 2 天的充电和放电 以便第 1 天考虑到第 2 天的使用情况 我编写了以下代码并且它有效 from
  • new 类名(). 方法名(); VS className ref = new className();

    我遇到了我的同事在一个内部使用的代码eventListner 即 private void someActionPerformed java awt event ActionEvent evt new className methodNam
  • makefile“没有规则来创建目标”错误

    我已经研究这个问题有一段时间了 但仍然不知道出了什么问题 我的 makefile 如下所示 F90 pgf90 NETCDF DIR opt netcdf LBS L NETCDF DIR lib lnetcdff lnetcdf INCL
  • 通过交互和指南修改 ggplot2 中的图例

    df lt data frame Depth c 1 2 3 4 5 6 7 8 Var1 as factor c rep A 4 rep B 4 Var2 as factor c rep c C D 4 Value runif 8 g l
  • Eclipse 给出错误“...不是链接资源的有效位置。”

    当我尝试在 Eclipse 中为构建路径配置添加新的类路径变量 并且我添加的路径是当前工作区是其子目录的目录时 Eclipse 给出错误 C JavaStuff is not a valid location for linked reso
  • WCF DataContract - 标记成员 IsRequired=false

    我有一份合同如下 DataContract public class MyObj DataMember IsRequired true public string StrA get private set DataMember IsRequ
  • 具有非常大的 HDF5 文件的 Tensorflow-IO 数据集输入管道

    我有非常大的训练 30Gb 文件 由于我的可用 RAM 无法容纳所有数据 因此我想批量读取数据 我看到有 Tensorflow io 包实施了一种方式 https www tensorflow org io api docs python