训练神经网络时资源耗尽 - keras

2023-12-11

我有一个包含 65668 个文件的数据集。

我使用 Keras 作为 CNN,这些是我的层:

embedding_layer = Embedding(len(word_index) + 1,
                        EMBEDDING_DIM,
                        weights=[embedding_matrix],
                        input_length=MAX_SEQUENCE_LENGTH,
                        trainable=True)
sequence_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')
embedded_sequences = embedding_layer(sequence_input)
x = Conv1D(128, 5, activation='relu')(embedded_sequences)
x = MaxPooling1D(5)(x)
x = Conv1D(256, 5, activation='relu')(x)
x = MaxPooling1D(5)(x)
x = Flatten()(x)
x = Dense(128, activation='relu')(x)
preds = Dense(len(labels_index), activation='softmax')(x)

第一个嵌入层在 GloVE.6B.100d 上进行训练。 拟合数据:

# fitting the data
model.fit(x_train, y_train, validation_data=(x_val, y_val),
      epochs=20, batch_size=128)

The MAX_SEQUENCE_LENGTH是 500。 我正在 GPU 上进行训练,Nvidia GeForce 940MX, 作为堆栈的一部分,我收到以下错误:

资源耗尽:在使用 shape[15318793,100] 分配张量并通过分配器 GPU_0_bfc 在 /job:localhost/replica:0/task:0/device:GPU:0 上键入 float 时出现 OOM

我尝试将批量大小减少到 16,甚至 8,但仍然遇到相同的错误。 问题可能是什么?


问题出在你的Embedding。它需要分配一个大小的矩阵15318793 * 100 * 4 bytes = 5.7 GB这绝对大于GeForce 940 MX记忆。有几种方法可以解决这个问题:

  1. 减少词汇量/语料库大小:尝试采取例如1M 最常用的单词而不是完整的单词集。这将大大减少嵌入矩阵的大小。

  2. 使用发电机代替Embedding: 而不是使用Embedding您可以使用生成器将序列转换为词向量序列。

  3. 使用线性变换Embedding而不是重新训练你的嵌入- 正如你提到的,带有标志trainable=False使您的算法正常工作,您可以将其设置为False并添加:

    Dense(new_embedding_size, activation='linear')(embedding)
    

    基于现有的嵌入来训练新的嵌入。

  4. 更换设备- 如果你有巨大的RAM内存你可以尝试以下策略:

    with tf.device('/cpu:0'):    
        embedding_layer = Embedding(len(word_index) + 1,
            EMBEDDING_DIM,
            weights=[embedding_matrix],
            input_length=MAX_SEQUENCE_LENGTH,
            trainable=True)
        sequence_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')
        embedded_sequences = embedding_layer(sequence_input)
    

    在本设计中计算Embedding层将使用CPU and RAM。缺点是在之间转移RAM and GPU可能真的很慢。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

训练神经网络时资源耗尽 - keras 的相关文章

  • 如何通过不规则索引获取子张量?

    我想通过不规则索引获得子张量 这是我的问题 Input tensor 2x8x10x1 Batch x Height x Width x Channel index Height 0 1 4 5 index Width 0 1 4 5 8
  • 如何在arm64主机上运行amd64 docker镜像

    警告 请求的映像平台 linux amd64 与检测到的主机平台 linux arm64 v8 不匹配 并且未请求特定平台 2021 07 28 22 25 06 349222 F tensorflow core platform cpu
  • Keras如何在Relu激活函数中使用max_value

    keras activation py 中定义的 Relu 函数为 def relu x alpha 0 max value None return K relu x alpha alpha max value max value 它有一个
  • TF map_fn 或 while_loop 用于不同形状的张量列表

    我想处理不同形状的张量序列 列表 并输出另一个张量列表 考虑每个时间戳上具有不同隐藏状态大小的 RNN 就像是 输入 tf ones 1 2 2 tf ones 2 2 3 tf ones 3 2 1 输出 tf zeros 1 2 4 t
  • 按相似度对矩阵进行排序

    我有 100 个矩阵 其中每一行对应一个个体 列对应站点 我想通过相似性度量对行进行排序 以便最相似的个体在矩阵中彼此相邻 我使用 k 近邻按行对矩阵进行排序 并将这些排序的矩阵提供给卷积神经网络 我想知道是否还有其他措施可以完成手头的任务
  • sigmoid激活函数可以用来解决Keras中的回归问题吗?

    我已经用 R 实现了简单的神经网络 但这是我第一次用 Keras 实现 所以希望得到一些建议 我在 Keras 中开发了一个神经网络函数来预测汽车销量 数据集可用here https github com allmydatasets dat
  • 使用预训练(Tensorflow)CNN 提取特征

    深度学习已成功应用于多个大型数据集 用于对少数类别 猫 狗 汽车 飞机等 进行分类 其性能优于 SIFT 特征袋 颜色直方图等更简单的描述符 然而 训练这样的网络需要每个类别大量的数据和大量的训练时间 然而 在花时间设计和训练这样一种设备并
  • 关于具有自定义损失的 3 输出 ANN 的加权

    我正在尝试定义一个自定义损失函数 它在回归模型中接收 3 个输出变量 def custom loss y true y pred y true c K cast y true float32 Shape batch size 3 y pre
  • Scipy 稀疏 CSR 矩阵到 TensorFlow SparseTensor - 小批量梯度下降

    我有一个 Scipy 稀疏 CSR 矩阵 它是根据 SVM Light 格式的稀疏 TF IDF 特征矩阵创建的 特征数量巨大且稀疏 所以我必须使用 SparseTensor 否则速度太慢 例如 特征数量为 5 示例文件如下所示 0 4 1
  • 在 R 中使用深度网络和 MNIST 数据读取手写数字第 3 部分

    我尝试编写一个基于深度网络的程序来读取手写数字 我在 Youtube 上找到了一个代码 https www youtube com watch v 5bso 5X7Zu4 https www youtube com watch v 5bso
  • Tensorflow 训练期间 GPU 使用率非常低

    我正在尝试为 10 类图像分类任务训练一个简单的多层感知器 这是 Udacity 深度学习课程作业的一部分 更准确地说 任务是对各种字体呈现的字母进行分类 数据集称为 notMNIST 我最终得到的代码看起来相当简单 但无论如何我在训练期间
  • 在防风草模型上使用 VIP 包计算重要性度量

    我正在尝试使用 vi firm 在防风草中制作的逻辑回归模型上计算特征重要性 对于正则表达式 我将使用 iris 数据集并尝试预测观察结果是否为 setosa iris1 lt iris gt mutate class case when
  • 使用神经网络包进行多项分类

    这个问题应该很简单 但文档没有帮助 我正在使用 R 我必须使用neuralnet多项式分类问题的包 所有示例均针对二项式或线性输出 我可以使用二项式输出进行一些一对一的实现 但我相信我应该能够通过使用 3 个单元作为输出层来做到这一点 其中
  • TensorFlow:Dst 张量未初始化

    The MNIST For ML Beginners当我运行时教程给我一个错误print sess run accuracy feed dict x mnist test images y mnist test labels 其他一切都运行
  • 将tensorflow 2.0 BatchDataset转换为numpy数组

    我有这个代码 train images test images tf keras datasets mnist load data train dataset tf data Dataset from tensor slices train
  • Tensorflow 可变图像输入大小(自动编码器、放大......)

    Edit WARNING不建议使用不同图像大小的图像 因为张量需要具有相同的大小才能实现并行化 我一直在寻找解决方案 了解如何使用不同大小的图像作为神经网络的输入 Numpy 第一个想法是使用numpy 然而 由于每个图像的大小不同 我无法
  • 用 tf.data 替换基于队列的输入管道

    我正在阅读 Ganegedara 的 NLP with Tensorflow 输入pipieline的介绍有以下例子 import tensorflow as tf import numpy as np import os Defining
  • 可以在 TensorFlow 中使用排名相关作为成本函数吗?

    我正在处理偶尔充满异常值的极其嘈杂的数据 因此我主要依靠相关性来衡量我的神经网络的准确性 是否可以明确使用诸如等级相关性 斯皮尔曼相关系数 之类的东西作为我的成本函数 到目前为止 我主要依赖 MSE 作为相关性的代理 我现在面临三个主要障碍
  • Keras conv1d 层参数:过滤器和 kernel_size

    我对 keras 的 conv1d 层中的这两个参数感到非常困惑 https keras io layers convolutional conv1d https keras io layers convolutional conv1d 文
  • 张量流:注册 numpy bfloat16 扩展

    正如我所见 tensorflow 中有 bfloat16 的 numpy 扩展 https github com tensorflow tensorflow blob 24ffe9f729160a095a5cab8f592392018280

随机推荐

  • 如何解绑和重新绑定

    archive click function event do something archive2 unbind click event 我有这个点击功能 我取消了绑定 但是 当我单击某个按钮时 我想再次绑定它 archive bind
  • 如何设置 SQL Server 2005 作业 CmdExec 超时

    我在 SQL Server 2005 中有一个作业设置 其中有一个操作系统 CmdExec 步骤 该步骤调用一个可能需要很长时间才能运行的程序 我发现 如果程序响应时间超过 1 分 40 秒 则该步骤将失败 并显示错误消息 操作已超时 该程
  • 检测类型是否是“映射”

    我想使用它们将 C 容器解析为另一个对象 iterator会员类型 迭代器成员类型指向单一类型 向量 队列等 对象的容器将变成类列表对象 迭代器成员类型指向单一类型对象的容器std pair将变成一个类似地图的物体 我试图编写一个成员函数来
  • Polymer,如何等待 core-ajax 完成后再渲染其他元素?

    更新 以下是针对这种情况的文档 条件模板使用 if 属性有条件地创建模板实例 这个应用程序 plnkr co 应执行以下操作 使用 core ajax 组件从数据库获取project location 本例中为JSON 使用 google
  • 如何在 Eclipse 中打印到 textArea 而不是控制台?

    我目前有一个程序 可以以各种方式将文本行打印到屏幕上 例如 System out println 语句 并且 for 循环将数组中的所有元素打印到屏幕上 我现在在一个单独的类中向该程序添加一个 GUI 我的问题是我想将打印到 Eclipse
  • 带坐标的平铺网格

    我正在尝试创建一个可以用 with 或 height 指定的网格 即 10 个框宽 x 20 个高 我已经创建了一个创建网格的脚本 但我想以一种可以以与我的方式不同的宽度和高度创建网格的方式进行制作 它当前创建一个宽度与高度相等的网格 并且
  • CryptoStream 没有像预期那样刷新

    我正在处理的 C NET Framework 4 5 代码应该允许我通过加密流将文本传输到另一个程序 我创建了两个简单的程序来演示我的问题 EncryptionTestA 是服务器 并且应该首先运行 EncryptionTestB 是客户端
  • 使用 Boost gzip_decompressor 解压缩内存中的数据

    我正在尝试使用 Boost 解压缩内存中的二进制数据gzip decompressor From 这个答案 我改编了以下代码 vector
  • 使用 Google 幻灯片中的应用程序脚本将 pageElements 置于前面或后面

    堆叠顺序由它们插入幻灯片的顺序决定 但是 幻灯片中的某些页面元素仍然隐藏 有没有办法使用应用程序脚本更改 Google 幻灯片中对象的顺序 这个解决方法怎么样 我和你经历过同样的情况 当时 我已经使用此解决方法移动了该元素 我认为针对这种情
  • 重新加载不同的表视图单元格后,NSOutlineView 行不再可以通过“Return”键编辑

    我遇到了最奇怪的问题NSOutlineView 一切都在故事板中设置 即大纲视图和两个NSTableCellViews 两个单元格视图基本相同 只有一个显示图标 另一个不显示 我可以通过按开始编辑项目 行 Return键 即NSTextFi
  • Dojo 实习生设置 firefox 配置文件名称

    您好 我正在尝试在环境设置中设置 Firefox 配置文件名称intern配置文件 我已经尝试过 environments browserName firefox firefox profile default firefox profil
  • css z-index 嵌套元素的问题

    我想在 z 平面上订购 3 个 HTML 元素 bank width 200px height 200px background color grey position absolute z index 100 transform tran
  • 表单关闭后从特定上下文运行代码?

    我想在此处创建的表单关闭后在此上下文中运行一些代码 Form1 Form1 new Form1 Form1 Show lt After this closes I want to run code from this context usi
  • 如何以编程方式更改第三台显示器

    当我使用笔记本电脑时 我使用 3 个显示器 笔记本电脑显示屏 第二台显示器 通过 VGA 连接 电视 通过 HDMI 连接 我的显卡不支持 3 个显示器 所以我不断地从 2 个显示器切换到 3 个显示器 当我在计算机上时 我使用第二个显示器
  • 在 Electron 中找不到模块

    我目前正在与 Babylon 一起开发 Electron 我发现这个仓库我基本上将其用作我自己项目的样板 一切都运行良好 直到我尝试添加jquery pep js用于其他需求 我一直犯这个错误 未捕获的错误 找不到模块 jquery pep
  • 如何在 JavaScript 中将麦克风静音

    所以我正在创建一个视频通话网络应用程序 我想在其中打开 关闭麦克风 打开 关闭视频功能 navigator mediaDevices getUserMedia video true audio true then stream gt con
  • 为什么在innerHTML 中使用Array#map 输出中的额外逗号?

    之前的帖子已经提到了如何toString 方法将在映射的每个项目之间放置逗号 并且可以通过使用来解决这个问题join 下面 尝试 2 在显示的对象之间添加了逗号 而尝试 1 则没有 为什么是这样 如何修改尝试 2 使其输出复制尝试 1 va
  • 用户定义类型作为 PostgreSQL 函数中的输入参数

    您好 我正在创建一个用于插入元数据的过程 我创建了类型 并在另一种类型中包含了一种类型 并且在过程中我对其进行迭代以获取值 由于我是 PostgreSQL 的新手 任何人都可以帮助我如何调用该过程 输入参数为类型 Create Type F
  • Netbeans 7.1.2 - 无法添加 glassfish 服务器 3.1.2

    我从下载 glassfish 服务器http glassfish java net downloads 3 1 2 2 final html并单独安装 现在我正在尝试将其添加到 Netbeans 中 但这不起作用 我做了以下步骤 以管理员身
  • 训练神经网络时资源耗尽 - keras

    我有一个包含 65668 个文件的数据集 我使用 Keras 作为 CNN 这些是我的层 embedding layer Embedding len word index 1 EMBEDDING DIM weights embedding