如何在Keras中使用数据生成器(data generators)的详细示例

2023-05-16

目录

  • 动机
  • 讲解
    • 以前的情况
    • 小提示
      • 数据产生器
      • Keras脚本
  • 可运行实例
    • 结论

动机

您是否曾经不得不加载一个非常消耗内存的数据集,以至于希望魔术能够无缝地解决这一问题?大型数据集正日益成为我们生活的一部分,因为我们能够利用数量不断增长的数据。

我们必须谨记,在某些情况下,即使是最先进的配置也没有足够的内存空间来像以前那样处理数据。这就是为什么我们需要找到其他有效地完成该任务的方法的原因。在此博客文章中,我们将向您展示如何实时在多个内核上生成数据集并将其立即馈送到您的深度学习模型。

本教程中使用的框架是Python的高级软件包Keras提供的框架,可以在TensorFlow或Theano的GPU安装之上使用。

讲解

以前的情况

在阅读本文之前,您的Keras脚本可能看起来像这样:

import numpy as np
from keras.models import Sequential

# Load entire dataset
X, y = np.load('some_training_set_with_labels.npy')

# Design model
model = Sequential()
[...] # Your architecture
model.compile()

# Train model on your dataset
model.fit(x=X, y=y)

本文全部涉及更改一次加载整个数据集的方式。确实,此任务可能会引起问题,因为所有训练样本可能无法同时放入内存中。

为了做到这一点,让我们深入研究如何构建适合此情况的数据生成器的逐步方法。顺便说一句,以下代码是用于您自己的项目的良好框架;您可以复制/粘贴以下代码,并相应地填充空白。

小提示

在开始之前,我们先看一些组织技巧,这些技巧在处理大型数据集时特别有用。

ID使用Python字符串来标识给定的数据集样本。跟踪样本及其标签的一种好方法是采用以下框架:

创建一个字典partition收集索引数据:

  • partition[‘train’]:train ID列表
  • partition[‘validation’]:validation ID列表
    创建一个字典labels,其中每个ID数据集的相关标签由给出labels[ID]

例如,假定我们的训练集包含id-1id-2id-3与相应的标签012,验证集合id-4与标签1。在这种情况下,Python的变量partitionlabels看起来像这样:

>>> partition
{'train': ['id-1', 'id-2', 'id-3'], 'validation': ['id-4']}
>>> labels
{'id-1': 0, 'id-2': 1, 'id-3': 2, 'id-4': 1}

此外,出于模块化的考虑,我们将在单独的文件中编写Keras代码和自定义类,以便您的文件夹看起来像

folder/
├── my_classes.py
├── keras_script.py
└── data/

假设data/是包含数据集的文件夹。

最后,需要注意的是,本教程中的代码是针对通用最少的,因此您可以轻松地将其调整为自己的数据集。

数据产生器

现在,让我们详细介绍如何设置Python类DataGenerator,该类将用于将实时数据馈入Keras模型。

首先,让我们编写该类的初始化函数。我们使后者继承了属性,keras.utils.Sequence以便可以利用诸如多重处理之类的出色功能。

def __init__(self, list_IDs, labels, batch_size=32, dim=(32,32,32), n_channels=1,
             n_classes=10, shuffle=True):
    'Initialization'
    self.dim = dim
    self.batch_size = batch_size
    self.labels = labels
    self.list_IDs = list_IDs
    self.n_channels = n_channels
    self.n_classes = n_classes
    self.shuffle = shuffle
    self.on_epoch_end()

我们将有关数据的相关信息作为参数,例如维度大小(例如,长度为32的卷dim=(32,32,32)),通道数,类数,批处理大小,或决定是否要在生成时改组数据。我们还存储重要信息,例如标签以及我们希望在每次通过时生成的ID列表。

在此,该方法on_epoch_end在每个时期的开始和结束时都会触发一次。如果shuffle参数设置为True,则每次通过时我们都会获得新的探索顺序(否则,只需保持线性探索方案即可)。

def on_epoch_end(self):
  'Updates indexes after each epoch'
  self.indexes = np.arange(len(self.list_IDs))
  if self.shuffle == True:
      np.random.shuffle(self.indexes)

调整将示例 fed to 分类器的顺序很有帮助,这样,各个时期之间的批处理看起来就不会相似。这样做最终将使我们的模型更强大。

生成过程最核心的另一种方法是完成最关键的工作:生成一批数据。调用负责此任务的私有方法__data_generation,并将目标批处理的ID列表作为参数。

def __data_generation(self, list_IDs_temp):
  'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
  # Initialization
  X = np.empty((self.batch_size, *self.dim, self.n_channels))
  y = np.empty((self.batch_size), dtype=int)

  # Generate data
  for i, ID in enumerate(list_IDs_temp):
      # Store sample
      X[i,] = np.load('data/' + ID + '.npy')

      # Store class
      y[i] = self.labels[ID]

  return X, keras.utils.to_categorical(y, num_classes=self.n_classes)

在数据生成期间,此代码从其对应的文件中读取每个示例的NumPy数组ID.npy。由于我们的代码是多核友好的,因此请注意,您可以执行更复杂的操作(例如,从源文件进行计算),而不必担心数据生成成为训练过程中的瓶颈。

另外,请注意,我们使用Keras keras.utils.to_categorical函数将存储的数字标签转换y[0 0 1 0 0 0]适合分类的二进制形式(例如,在6类问题中,第三个标签对应于)。

现在是我们将所有这些组件组合在一起的部分。每个调用都请求一个介于0和批次总数之间的批次索引,其中在__len__方法中指定了后者。

def __len__(self):
  'Denotes the number of batches per epoch'
  return int(np.floor(len(self.list_IDs) / self.batch_size))

通常的做法是将此值设置为
⌊ #  samples batch size ⌋ \biggl\lfloor\frac{\#\textrm{ samples}}{\textrm{batch size}}\biggr\rfloor batch size# samples
这样该模型每个时期最多可以看到一次训练样本。

现在,当调用与给定索引相对应的批处理时,生成器将执行__getitem__生成该方法的方法。

def __getitem__(self, index):
  'Generate one batch of data'
  # Generate indexes of the batch
  indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

  # Find list of IDs
  list_IDs_temp = [self.list_IDs[k] for k in indexes]

  # Generate data
  X, y = self.__data_generation(list_IDs_temp)

  return X, y

下面显示了与我们在本节中描述的步骤相对应的完整代码。

import numpy as np
import keras

class DataGenerator(keras.utils.Sequence):
    'Generates data for Keras'
    def __init__(self, list_IDs, labels, batch_size=32, dim=(32,32,32), n_channels=1,
                 n_classes=10, shuffle=True):
        'Initialization'
        self.dim = dim
        self.batch_size = batch_size
        self.labels = labels
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Find list of IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]

        # Generate data
        X, y = self.__data_generation(list_IDs_temp)

        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
        # Initialization
        X = np.empty((self.batch_size, *self.dim, self.n_channels))
        y = np.empty((self.batch_size), dtype=int)

        # Generate data
        for i, ID in enumerate(list_IDs_temp):
            # Store sample
            X[i,] = np.load('data/' + ID + '.npy')

            # Store class
            y[i] = self.labels[ID]

        return X, keras.utils.to_categorical(y, num_classes=self.n_classes)

Keras脚本

现在,我们必须相应地修改Keras脚本,以便它接受我们刚刚创建的生成器。

import numpy as np

from keras.models import Sequential
from my_classes import DataGenerator

# Parameters
params = {'dim': (32,32,32),
          'batch_size': 64,
          'n_classes': 6,
          'n_channels': 1,
          'shuffle': True}

# Datasets
partition = # IDs
labels = # Labels

# Generators
training_generator = DataGenerator(partition['train'], labels, **params)
validation_generator = DataGenerator(partition['validation'], labels, **params)

# Design model
model = Sequential()
[...] # Architecture
model.compile()

# Train model on dataset
model.fit_generator(generator=training_generator,
                    validation_data=validation_generator,
                    use_multiprocessing=True,
                    workers=6)

正如你所看到的,我们从所谓modelfit_generator方法,而不是fit,我们刚刚给我们的训练发生器的参数之一。Keras负责其余的工作!

请注意,我们的实现允许使用的multiprocessing参数fit_generator,其中指定的线程数是workers并行生成批处理的线程数。足够多的工作人员可以确保有效地管理CPU计算,瓶颈确实是神经网络在GPU上进行的向前和向后操作(而不是数据生成)。

可运行实例

import numpy as np
import keras
from keras.layers import Dense


class DataGenerator(keras.utils.Sequence):
    'Generates data for Keras'

    def __init__(self, train_data, test_data, batch_size=32, n_channels=1, shuffle=True):
        'Initialization'
        self.train_data = train_data
        self.test_data = test_data
        self.batch_size = batch_size
        self.n_channels = n_channels
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.train_data) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]

        # Generate data
        # X, y = self.__data_generation(train_data_temp)

        X = self.train_data[indexes]
        y = self.test_data[indexes]

        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.train_data))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)


def gen_model():

    model = keras.models.Sequential()

    model.add(Dense(units=1, use_bias=False, input_shape=(1,)))  # 仅有的1个权重在这里

    return model


if __name__ == '__main__':

    # 数据比较简单,用 CPU 即可
    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    # Parameters
    params = {'batch_size': 10,
              'n_channels': 1,
              'shuffle': False}

    # Datasets
    # partition =  # IDs
    # labels =  # Labels
    x = {}
    x['train'] = np.arange(100, dtype='int32')
    x['validation'] = np.arange(100, 120, dtype='int32')
    y = {}
    y['train'] = -x['train']
    y['test'] = -x['train']


    # Generators
    training_generator = DataGenerator(x['train'], y['train'], **params)
    # validation_generator = DataGenerator(x['validation'], y['test'], **params)

    # Design model
    model = gen_model()
    model.compile(loss='mse', optimizer='adam')

    # model.fit(x['train'], y['train'], epochs=1000, batch_size=10, verbose=2)

    # Train model on dataset
    model.fit_generator(generator=training_generator,
                        # validation_data=validation_generator,
                        epochs=2000,
                        verbose=2,
                        workers=6,
                        )

    print(model.layers[0].get_weights())

结论

就是这个!您现在可以使用以下命令运行Keras脚本

python3 keras_script.py

和你会看到,在训练阶段期间,数据并行地由CPU生成,然后直接喂入到GPU

你可以找到这个战略上的一个具体的例子应用的完整例子GitHub上,其中的代码数据的生成以及对Keras脚本可用。


ref:
姊妹文章:Keras数据生成器以及如何使用它们

A detailed example of how to use data generators with Keras

Artificial Intelligence cheatsheets

Machine Learning cheatsheets

Ddeep Learning cheatsheets

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在Keras中使用数据生成器(data generators)的详细示例 的相关文章

  • Lodash 核心 lodash

    baseCreate Object create ptoto propertiesObject 不能以名称而去定义该方法的作用 它并不只是为了创建一个对象 其可以理解为 34 创建一个继承了指定对象的对象 34 并且创建后的对象是不存在原型
  • MySql之索引

    索引 1 gt 什么是索引 xff1f 2 gt MySql的查询方式3 gt 索引的实现原理4 gt 索引应用在什么情况下5 gt 索引的创建 删除6 gt 索引的常见失败情况第一种情况即使添加了索引也不走索引的原因 第二种情况第三种情况
  • VNC远程管理配置

    其实配置VNC很简单 xff0c 只要运行vncserver就好了 运行完毕后 xff0c 它会在家目录生成 vnc目录 xff0c 里面最重要的一个文件是Xstartup 然后你可以使用vncviewer yourremotehost i
  • uCOS/FreeRTOS任务创建的两种模式

    在我们使用uCOS FreeRTOS编写代码时 xff0c 首先要面临的一个问题是怎样创建任务并启动整个系统 一般来说 xff0c 我们会有两种不同的方式 这两种方式不仅适用于uCOS FreeRTOS xff0c 同时也适用于其它RTOS
  • 树莓派 —— 配置Windows通过VNC连接树莓派

    VNC简介 VNC Virtual Network Console 是虚拟网络控制台的缩写 它 是一款优秀的远程控制工具软件 xff0c 由著名的 AT amp T 的欧洲研究实验室开发的 VNC 是在基于 UNIX 和 Linux 操作系
  • 测试UDP端口连通性

    测试UDP端口连通性 Linux使用netcat测试udp端口Centos7安装netcat 依赖epel源 netcat常用参数使用netcat创建TCP客户端和服务器使用netcat创建UDP客户端和服务器 windows使用netca
  • Centos7安装kvm服务器

    Centos7安装kvm服务器 什么是kvmvirt manager及相关软件简介virt manger架构及原理KVMQEMULibvirt 检查硬件是否支持kvm虚拟化启用嵌套虚拟化 可选 检查是否启用嵌套虚拟化热生效嵌套虚拟化 临时启
  • IP子网划分与计算

    IP地址分为5类 A类 xff1a 0 0 0 0 127 255 255 255 8bit B类 xff1a 128 0 0 0 191 255 255 255 16bit C类 xff1a 192 0 0 0 223 255 255 2
  • 透镜成像原理,眼球成像原理,小孔成像原理

    透镜成像规律总结 规律1 xff1a 当物距大于 2 倍 焦距 时 xff0c 则像距在1 倍焦距和 2 倍焦距之间 xff0c 成倒立 缩小的实像 此时像距小于物距 xff0c 像比物小 xff0c 物像异侧 应用 xff1a 照相机 摄
  • Ubuntu语言支持为灰色修复方法

    在Ubuntu12 04中 xff0c 在下不知为何将语言支持中应用到整个系统和添加语言这2个按弄成了灰色 xff0c 导致ibus 不能输入中文 xff0c 修复方法如下 xff1a 1 启动terminal xff0c 输入如下命令 x
  • torch中Tensor和numpy相互转化

    Numpy转为Tensor 使用torch from numpy 如 import torch B 61 torch from numpy A Tensor转为Numpy 使用data numpy 如 import torch C 61 B
  • LXD/LXC raw.idmap 使用方法和作用

    官方文档 xff1a https github com lxc lxd blob master doc userns idmap md 设置和取消设置 raw idmap xff0c 都需要重启容器才能应用 并且在重启容器时 xff0c 会
  • LXD/LXC 奇怪的重启断网问题解决。

    2023 4 10 日更新 搞 Debian 系统的 LXD 时 xff0c 发现了 Debian 对 lxd 的已知问题跟踪 似乎是 Docker 的原因 已知问题 跟踪链接 xff1a https wiki debian org LXD
  • SSH连接问题:连不上&不能免密登录

    一 连不上 ssh username 64 ip 报错 xff1a ssh connect to host lt ip4地址 gt port 22 Connection timed out 首先排查 xff0c 排查步骤 xff1a 1 p
  • Java经典面试题总结

    本文分为十九个模块 xff0c 分别是 xff1a Java 基础 容器 多线程 反射 对象拷贝 Java Web 异常 网络 设计模式 Spring Spring MVC Spring Boot Spring Cloud Hibernat
  • 单例模式常见场景

    单例模式 Singleton 也叫单态模式 xff0c 是设计模式中最为简单的一种模式 xff0c 甚至有些模式大师都不称其为模式 xff0c 称其为一种实现技巧 xff0c 因为设计模式讲究对象之间的关系的抽象 xff0c 而单例模式只有
  • Java 基础系列(十) --- 什么是向上转型和向下转型

    1 向上转型 1 1 为何叫向上转型 在面向对象程序设计中 针对一些复杂的场景 我们通常画一个UML图来表示各个类之间的关系 通常父类画在子类的上方 因此我们就称之为 34 向上转型 34 表示往父类的方向转 向上转型发生的时机 直接赋值
  • Python函数式编程——map()、reduce()

    原文链接 提起map和reduce想必大家并不陌生 xff0c Google公司2003年提出了一个名为MapReduce的编程模型 1 xff0c 用于处理大规模海量数据 xff0c 并在之后广泛的应用于Google的各项应用中 xff0
  • [Linux] CentOS8 升级

    A CentOS8 0升级到8 5的方法 由于CentOS8已经仅仅维护Stream xff0c 8 Linux都已经不在维护 对应仓库都清空了 不过有时候 xff0c 我们依然需要安装对应小版本 xff0c 比如8 5 这里摸索了一个更新
  • word 插入公式附加右侧编号方法

    主要添加编号的方法就是在公式后面 xff0c 添加 编号 xff0c 输入光标在公式的最后 xff0c 然后回车 效果如下 xff1a 需要注意的有两点 xff1a 1 必须要保证 不属于公式内部 xff08 如果不清楚如何保证 键属不属于

随机推荐

  • iscsiadm命令用法

    启动iscsi守护进程 span class token function service span iscsi start 发现目标 iscsiadm m discovery t sendtargets p 192 168 1 1 326
  • centOS7关闭防火墙

    查看防火墙状态 xff1a systemctl status firewalld service 如图 绿的running表示防火墙开启 执行关闭命令 xff1a systemctl stop firewalld service 再次执行查
  • Linux 包管理基础:apt、yum、dnf 和 pkg常用命令

    介绍 大多数现代的类 Unix 操作系统都提供了一种中心化的机制用来搜索和安装软件 软件通常都是存放在存储库中 xff0c 并通过包的形式进行分发 处理包的工作被称为包管理 包提供了操作系统的基本组件 xff0c 以及共享的库 应用程序 服
  • opensd开源啦 !这套自动化部署OpenStack工具你值得拥有

    2022年8月 xff0c 经openEuler开源社区技术委员会审议通过 xff0c 联通数科正式将opensd开源至openEuler开源社区 opensd是联通数科为解决OpenStack企业级部署的复杂性 xff0c 针对自身Ope
  • 边缘计算的解决方案大集合

    自今年2月的巴塞罗那世界移动通信大会召开以来 xff0c 边缘计算无疑是C位出道 xff0c 爆发释放在人们的视野中 xff0c 成为今年业界最热门的领域之一 顺着5G的东风 xff0c 边缘计算的诞生成为历史必然 xff0c 整个行业都在
  • 计网(笔记版)---外部网络路由协议之BGP协议

  • js中?. 、?? 、??=的用法及含义

    1 可选链运算符 是不是经常遇到这样的错误 TypeError Cannot read properties of null reading 39 xxx 39 引入可选链就是为了解决这个问题 const person 61 id 1 na
  • LXC与Docker介绍

    文章目录 LXCLUX是什么LXC常用命令LXC的使用 Docker容器虚拟化和传统虚拟化的区别Linux NamespacesCGroupsdoeker基本概念docker容器编排 LXC LUX是什么 LXC xff08 LinuX C
  • 树莓派使用CLASH的代理安装软件

    为什么要使用代理 github系列域名不能访问 xff0c curl一键安装用不了 打开CLASH允许局域网功能 树莓派终端登陆 方法一 xff1a 1 编辑 etc profile文件 sudo nano etc profile 2 在最
  • Linux下解决高并发socket最大连接数限制,tcp默认1024个连接

    linux获取TCP连接数 方法一 xff1a admin 64 zabbix ss ant awk 39 NR gt 1 a 1 43 43 END for b in a print b a b 39 ESTAB 535 TIME WAI
  • vncserver的详细配置

    原文地址 xff1a vncserver的详细配置 作者 xff1a OpenTech 1 首先要配置的是服务端 A 确认服务器端是否安装了vncserver 使用rpm qa vnc命令如果收到如下信息说明已经安装了vncserver x
  • Rman备份中常见的问题

    1 xff1a ORA 01031insufficient privileges gpasswd d oracle dba 将oracle移除出dba组 查看oracle属性 uid 61 500 oracle gid 61 500 oin
  • 弱监督学习-snorkel

    1 什么是弱监督学习 弱监督问题旨在研究通过较弱的监督信号来构建预测模型 xff0c 即在少量的标注样本上学习建模 xff0c 达到大量样本上同样的效果 弱监督学习主要分为三类 不确切监督 xff08 inexact supervision
  • 如何用3000元搞定一年100M点对点专线

    温馨提示 xff1a 阅读本文需要先阅读或温习格物资讯早先发布过的 奇葩物花生壳出品蒲公英VPN组网路由 2015年11月格物资讯发布了花生壳打洞路由器蒲公英的试用报告 xff0c 提到在鹏博士接入前提下 xff0c 做P2P组网实现快速的
  • PNETLAB中可以导入的交换机、防火墙等设备镜像

    在网上找了很久 xff0c 想要找到一个设备镜像的下载 xff0c 发现网上全都是一些对于PNET本体安装的炒冷饭 不过经过一个下午的寻找 xff0c 最终在B站一个UP 64 real半吊子工程师 22年的视频里找到了相关的下载平台连接
  • 在Keras中,TimeDistributed层的作用是什么?

    在Keras中 xff0c TimeDistributed层的作用是什么 xff1f 关键词 xff1a python xff0c machine learning xff0c keras xff0c neural network xff0
  • 理解1D、2D、3D卷积神经网络的概念

    目录 引言二维CNN Conv2D一维CNN Conv1D三维CNN Conv3D总结 引言 当我们说卷积神经网络 xff08 CNN xff09 时 xff0c 通常是指用于图像分类的二维CNN 但是 xff0c 现实世界中还使用了其他两
  • 解决vncserver看不到桌面的问题

    解决vncserver看不到桌面的问题 主要参考这里 xff1a http zhidao baidu com link url 61 7Btj0KsV5b986dydoOpElKDpSwriaruP4jxWY6f6pG3Ota kcQbdV
  • 深入理解 keras 中 Dense 层参数

    目录 引言深入理解 Dense 层的用法查看参数输入尺寸输出尺寸示例 xff1a 用法完整示例示例一 最小网络示例二 xff1a 多维度数据示例三 xff1a 特殊情况 xff0c 待讨论 附录 引言 大家或许已经对深度学习不陌生了 不管是
  • 如何在Keras中使用数据生成器(data generators)的详细示例

    目录 动机讲解以前的情况小提示数据产生器Keras脚本 可运行实例结论 动机 您是否曾经不得不加载一个非常消耗内存的数据集 xff0c 以至于希望魔术能够无缝地解决这一问题 xff1f 大型数据集正日益成为我们生活的一部分 xff0c 因为