keras 对于大数据的训练,无法一次性载入内存,使用迭代器

2023-11-05

说明:我是在keras的官方demo上进行修改https://github.com/fchollet/keras/blob/master/examples/imdb_cnn.py

1、几点说明,从文件中读入数据,会降低GPU的使用率,如果能够直接将数据载入内存,GPU的使用率会比较高。下面进行对比:

全部数据载入内存,GPU的使用率:

使用队列,边读数据边进行训练:


结论:全部载入内存,GPU的使用率可以达到82%,如果边载入数据边训练,只能达到48%


2、keras 使用迭代器来实现大数据的训练,其简单的思想就是,使用迭代器从文件中去顺序读取数据。所以,自己的训练数据一定要先随机打散。因为,我们的迭代器也是每次顺序读取一个batch_size的数据进行训练。

举例如下:数据如下,前400维是特征,后一维是label


keras 官方的demo 如下:

def generate_arrays_from_file(path):
    while 1:
    f = open(path)
    for line in f:
        # create Numpy arrays of input data
        # and labels, from each line in the file
        x, y = process_line(line)
        yield (x, y)
    f.close()

model.fit_generator(generate_arrays_from_file('/my_file.txt'),
        samples_per_epoch=10000, nb_epoch=10)
说明:官方的demo还是有瑕疵的,没有实现batch_size,该demo每次只能提取一个样本。我针对上述的数据集,实现的batch_size数据提取的迭代器,代码如下:

def process_line(line):
    tmp = [int(val) for val in line.strip().split(',')]
    x = np.array(tmp[:-1])
    y = np.array(tmp[-1:])
    return x,y

def generate_arrays_from_file(path,batch_size):
    while 1:
        f = open(path)
        cnt = 0
        X =[]
        Y =[]
        for line in f:
            # create Numpy arrays of input data
            # and labels, from each line in the file
            x, y = process_line(line)
            X.append(x)
            Y.append(y)
            cnt += 1
            if cnt==batch_size:
                cnt = 0
                yield (np.array(X), np.array(Y))
                X = []
                Y = []
    f.close()

训练时候的代码如下:

model.fit_generator(generate_arrays_from_file('./train',batch_size=batch_size),
        samples_per_epoch=25024,nb_epoch=nb_epoch,validation_data=(X_test, y_test),max_q_size=1000,verbose=1,nb_worker=1)

3、关于samples_per_epoch的说明:

我的训练数据,train只有25000行,batch_size=32。照理说samples_per_epoch=32,但是会有警告.UserWarning: Epoch comprised more than `samples_per_epoch` samples, which might affect learning results


说明:这个出错的原因是train的数目/batch_size不是整数。可以将samples_per_epoch = ceil(train_num/batch_size) *batch_size.设置完的结果为88.72%:


keras的demo使用的方法是将全部数据载入进来训练:


demo的结果为88.86%,所以,该数据读取的方式基本没问题。但是,一定要将数据先进行打乱。如果能全部载入内存,就全部载入内存,速度会快不少



本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

keras 对于大数据的训练,无法一次性载入内存,使用迭代器 的相关文章

  • 如何将 model.summary() 保存到 Keras 中的文件?

    有model summary 方法 https keras io models about keras models 在喀拉斯 它将表打印到标准输出 是否可以将其保存到文件中 如果您想要摘要的格式 您可以传递print功能为model su
  • Keras Maxpooling2d 层给出 ValueError

    我正在尝试在 keras 中复制 VGG16 模型 以下是我的代码 model Sequential model add ZeroPadding2D 1 1 input shape 3 224 224 model add Convoluti
  • Keras AttributeError:“顺序”对象没有属性“predict_classes”

    我试图按照本指南找到模型性能指标 F1 分数 准确性 召回率 https machinelearningmastery com how to calculate precision recall f1 and more for deep l
  • 如何防止 Keras 在训练期间计算指标

    我正在使用 Tensorflow Keras 2 4 1 并且有一个 无监督的 自定义指标 它将我的几个模型输入作为参数 例如 model build model returns a tf keras Model object my met
  • 使用 theano 进行多处理

    我正在尝试将 theano 与 cpu 多处理和神经网络库 Keras 结合使用 I use device gpu标记并加载 keras 模型 然后 为了提取超过一百万张图像的特征 我使用多处理池 该函数看起来像这样 from keras
  • Keras LSTM 密集层多维输入

    我正在尝试创建一个 keras LSTM 来预测时间序列 我的 x train 形状像 3000 15 10 示例 时间步长 特征 y train 形状像 3000 15 1 我正在尝试构建一个多对多模型 每个序列 10 个输入特征产生 1
  • Keras 序列模型中的数据增强层

    我正在尝试将数据增强作为一个层添加到模型中 但我遇到了我认为是形状问题 我也尝试在增强层中指定输入形状 当我取出data augmentation模型中的图层运行良好 preprocessing RandomFlip horizontal
  • 在不丢失基数信息的情况下对 TensorFlow 数据集进行窗口处理?

    tf data Dataset window返回一个新的数据集 其元素是数据集 这些嵌套数据集的元素是所需大小的窗口 如果您有一个数据集 例如 Dataset range 10 并想要一个像这样的窗口数据集 0 1 2 1 2 3 7 8
  • Native TF 与 Keras TF 性能比较

    我使用本机和后端张量流创建了完全相同的网络 但在使用多个不同参数进行了多个小时的测试后 仍然无法弄清楚为什么 keras 优于本机张量流并产生更好 稍微但更好 的结果 Keras 是否实现了不同的权重初始化方法 或者执行除 tf train
  • 批量归一化,是还是否?

    我使用 Tensorflow 1 14 0 和 Keras 2 2 4 以下代码实现了一个简单的神经网络 import numpy as np np random seed 1 import random random seed 2 imp
  • 为什么不使用均方误差来解决分类问题?

    我正在尝试使用 LSTM 解决一个简单的二元分类问题 我正在尝试找出网络的正确损失函数 问题是 当我使用二元交叉熵作为损失函数时 与使用均方误差 MSE 函数相比 训练和测试的损失值相对较高 经过研究 我发现二元交叉熵应该用于分类问题 MS
  • ValueError:张量:(...)不是该图的元素

    我正在使用 keras 的预训练模型 在尝试获取预测时出现错误 我在烧瓶服务器中有以下代码 from NeuralNetwork import app route uploadMultipleImages methods POST def
  • 在相同任务上,Keras 比 TensorFlow 慢

    我正在使用 Python 运行斩首 DCNN 本例中为 Inception V3 来获取图像特征 我使用的是 Anaconda Py3 6 和 Windows7 使用 TensorFlow 时 我将会话保存在变量中 感谢 jdehesa 并
  • 无法获取未知等级的 Shape 长度

    我有一个神经网络 来自tf data数据生成器和tf keras模型 如下 简化版本 因为太长 dataset A tf data Dataset反对与next x方法调用get next为了x train迭代器和next y方法调用get
  • 如何使用 Tensorflow-GPU 和 Keras 修复低易失性 GPU-Util?

    我有一台 4 GPU 机器 在上面运行带有 Keras 的 Tensorflow GPU 我的一些分类问题需要几个小时才能完成 nvidia smi returns Volatile GPU Util which never exceeds
  • Keras model.predict 函数给出输入形状错误

    我已经在 Tensorflow 中实现了通用句子编码器 现在我正在尝试预测句子的类概率 我也将字符串转换为数组 Code if model model type universal classifier basic class probs
  • 使用 Keras np_utils.to_categorical 的问题

    我正在尝试将整数的 one hot 向量数组制作为 keras 将能够使用的 one hot 向量数组来拟合我的模型 这是代码的相关部分 Y train np hstack np asarray dataframe output vecto
  • MultiHeadAttention Attention_mask [Keras、Tensorflow] 示例

    我正在努力掩盖 MultiHeadAttention 层的输入 我正在使用 Keras 文档中的 Transformer Block 进行自我关注 到目前为止 我在网上找不到任何示例代码 如果有人能给我一个代码片段 我将不胜感激 变压器块来
  • 将 tf.contrib.layers.xavier_initializer() 更改为 2.0.0

    我该如何改变 tf contrib layers xavier initializer tf 版本 gt 2 0 0 所有代码 W1 tf get variable W1 shape self input size h size initi
  • 在按顺序读取的多个特征文件上训练 Keras 模型以节省内存

    当我尝试读取大量功能文件时 我遇到了内存问题 见下文 我想我应该分割训练文件并按顺序读取它们 做到这一点的最佳方法是什么 x train np load path features x train npy y train np load p

随机推荐

  • radare2 使用记录

    radare2 使用记录 编译 调试分析 数据结构 rasm disasm analop 反汇编 cs disasm libarch 编译 radare2 UNIX like reverse engineering framework an
  • VSCode 无法跳转C语言函数定义和变量定义的解决方案(本地端+远程服务器端)

    文章目录 前言 1 给本地端安装 C C 插件 2 给远程服务器端安装 C C 插件 小结 前言 初次使用 VSCode 编辑代码时 估计有不少小伙伴遇到过点击函数或变量无法跳转到定义处 左侧大纲栏里也没有任何内容的情况 这是缺少 C C
  • Vue项目运行报错:operty or method “xxx“ is not defined on the instance but referenced during render.

    报错原因 属性或方法 xxx未在实例上定义 但在渲染过程中被引用 解决方法 定义这个属性或者方法 1 只渲染了 没有定义 2 定义属性或方法 注意 如果定义了还是报这个错误 那么请一定检查定义的位置是不是正确的 博主偶尔也会出现这个问题 但
  • Spring cloud alibaba sentinel 实战

    Sentinel 分布式系统流量防卫兵 一 简介 二 特性 三 概念 四 安装 4 1 本地安装 4 2 docker 安装 五 实例 5 1 启动sentinel 5 2 模块配置 六 持久化配置 七 注意 6 1 SentinelRes
  • 奥拉星登陆显示网络或服务器,《奥拉星手游》进不去游戏怎么回事 无法进入游戏解决方法分享...

    导 读 奥拉星手游进不去游戏怎么办 很多玩家都卡顿在游戏外面了 那么遇到这个问题如何解决呢 下面九游小编为大家介绍奥拉星无法进入游戏解决方法 奥拉星无法进入游戏解决方法 目前测试服刚刚开服 人数在一时 奥拉星手游进不去游戏怎么办 很多玩家都
  • 二叉树(构造篇)

    二叉树 纲领篇 先复述一下前文总结的二叉树解题总纲 是否可以通过遍历一遍二叉树得到答案 如果可以 用一个 traverse 函数配合外部变量来实现 这叫 遍历 的思维模式 是否可以定义一个递归函数 通过子问题 子树 的答案推导出原问题的答案
  • 华为OD机试 - 判断一组不等式是否满足约束并输出最大差(Java)

    题目描述 给定一组不等式 判断是否成立并输出不等式的最大差 输出浮点数的整数部分 要求 不等式系数为 double类型 是一个二维数组 不等式的变量为 int类型 是一维数组 不等式的目标值为 double类型 是一维数组 不等式约束为字符
  • sublime添加直接运行语言的方法

    Tools Build system New Build System 添加新的编译文件 添加lua编译环境 cmd usr local bin lua file file regex lua t 0 9 0 9 selector sour
  • js检索关键字

    var i str indexOf 关键字 formi 查找str中formi的位置之后的下一个关键字的下标值 如果省略第二个关键字 则默认从0开始查找 如果没有找到 则返回 1 var i str lastIndexOf 关键字 form
  • title=“{{item.id}}“: Interpolation inside attributes has been removed. Use v-bind or the colon short

    title item id Interpolation inside attributes has been removed Use v bind or the colon shorthand instead v for列表渲染中给a ca
  • python中codecs模块_python自然语言编码转换模块codecs介绍

    python对多国语言的处理是支持的很好的 它可以处理现在任意编码的字符 这里深入的研究一下python对多种不同语言的处理 有一点需要清楚的是 当python要做编码转换的时候 会借助于内部的编码 转换过程是这样的 原有编码 gt 内部编
  • Linux环境下安装JDK

    安装jdk有两种方法 手动安装 yum安装 yum安装如下 1 查询要安装jdk的版本 yum y list java 2 安装jdk1 8 yum install y java 1 8 0 openjdk x86 64 3 查询jdk版本
  • Windows 系统中安装 MySQL 5.6 zip 步骤并配置 root 密码

    说明 最早我安装 MySQL 还是使用安装版的进行安装 但是作为一名程序员 MySQL 公司既然提供了 zip 版本的安装 我们当然要使用这种绿色的安装方式 MySQL 5 6 版本和 5 7 版本的安装步骤有很大不同 本文记录了 5 6
  • ultraiso制作u盘启动盘教程图文详解纯净-U盘启动教程

    制作u盘启动盘用软碟通ultraiso轻松制作纯净windows7 u盘装系统 网友们除了知道的u大师u盘启动盘制作工具 u启动 u深度 老毛桃 大白菜u盘启动盘制作工具外 还有量产 fbinstTool 我这再介绍一种u盘启动盘的制作方式
  • maven子依赖版本覆盖父依赖

    比如父依赖定义了jaskson version为2 13 3 在
  • 01-Java基础-变量

    一 变量介绍 变量就是向操作系统申请内存来存储值 也就是说 当创建变量的时候 需要在内存中申请空间 内存管理系统根据变量的类型为变量分配存储空间 分配的空间只能用来储存该类型数据 简单理解 类似数学中的 设 x 1 在程序中就表示声明了一个
  • 移动级处理芯片岁末盘点

    时间过得真快 不知不觉间又到了年关 这就说明一年一度做盘点汇总的时候也要到了 作为即将踏入这个科技行业快有三个年头的媒体人 笔者也这在这段时间内跟随新兴的移动互联网市场一起成长着 同时也看尽了这三年来行业里无情的变迁 感叹身在同一个行业里厂
  • 层次分析法(多准则决策方法)

    这是介于定量分析与定性分析的一种方法 运用层次分析法建模 大体上可按下面四个步骤进行 建立递阶层次结构模型 构造出各层次中的所有判断矩阵 层次单排序及一致性检验 层次总排序及一致性检验 建立递阶层次结构模型 每一层次中各元素所支配的元素一般
  • 笔试题13:采用UDP协议,编写一个简单发送字符串的程序(源码)

    UDP协议是一种无须建立连接的网络通信协议 采用Java来编写 一般有以下几个步骤 包括接收端和发送端 1 创建数据Socket 指定一个端口号 2 对于接收消息的一端来说 提供一个byte数组进行数据的存储 而对于发送消息一端 除此之外还
  • keras 对于大数据的训练,无法一次性载入内存,使用迭代器

    说明 我是在keras的官方demo上进行修改https github com fchollet keras blob master examples imdb cnn py 1 几点说明 从文件中读入数据 会降低GPU的使用率 如果能够直