PyTorch:Dataset()与Dataloader()的使用详解

2023-05-16

目录

1、Dataset类的使用

2、Dataloader类的使用

3、总结


Dataset类与Dataloader类是PyTorch官方封装的用于在数据集中提取一个batch的训练用数据的接口,其实我们也可以自定义获取每个batch的方法,但是对于大数据量的数据集,直接用封装好的接口会很大程度上提升效率。

一般情况下,Dataset类与Dataloader类是配合着使用的,Dataset负责整理数据,Dataloader负责在整理好的数据中按照一定的规则取出batch_size个数据来供网络训练使用。

1、Dataset类的使用

Dataset用以整理数据集。我们整理数据的目的是为了Dataloader可以方便的从整理后的和数据中获取一个batch的数据来供网络进行训练。

先看一下官方的Dataset的源码:

class Dataset(object):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]

很明显,这个类内部什么方法的实现都没有,就是用来让我们继承重写的。当我们继承该类时,必须重写里面的__getitem__(self, index)方法。该方法定义了使用索引值来查找元素的方法,即假如我们定义一个自己的训练数据集实例traindata,如果想使用traindata[index]的方式来获取索引为index的数据,我们就得实现__getitem__方法。这样当我们调用traindata[index]索引数据时,其实就是自动调用__getitem__(self, index)方法来实现的。另外,我们还可以重写__len__(self)方法,用以使用len(traindata)方法来获取我们整个数据集的数量。如果还不清楚,可以细细品一下下面的例子:

class TrainData(Dataset):  # 继承Dataset类并重写相关的方法
    ...
    def __getitem__(self, index):
        '''编写自己的数据获取方式'''
        return [x_data, y_lable]

    def __len__(self):
        '''编写获取数据集大小的实现方式'''
        return length


traindata = TrainData(mydataset)   # 定义一个实例
first = traindata[0]       # 获取数据集中的第一组数据,会自动调用__getitem__
length = len(traindata)    # 获取数据集的数据量的方法,会自动调用__len__

2、Dataloader类的使用

整体上来说,Dataloader类就是从上面封装好的数据中按照给定的方式来一次一次地抽取一个batch的数据来供网络进行训练,其内部使用的是yield生成器机制。Dataloader不用继承重写,我们直接实例化就行。下面我们接着上面的例子来继续了解下Dataloader从数据集中取出一个batch数据的过程:

首先,定义一个Dataloader实例gen_train:

gen_train = Dataloader(traindata, batch_size=4, num_workers=4, pin_memory=True, drop_last=True, collate_fn=my_collate_fn)

关于有关参数的说明(没用到的参数就不解释了):

1、traindata(Dataset): 传入的数据集,按自己定义的Dataset实例名来传入,我这里是traindata
2、batch_size(int, optional): 每个batch有多少个样本
3、num_workers (int, optional): 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)
4、pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.
5、drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为4,而一个epoch只有100个样本,那么训练的时候后面的2个因为不满足组成一个batch就被扔掉了。如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。
6、collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数

可以看到,gen_train从traindata中返回的是一个含有batch_size(4)个数据([x_data, y_label])的mini_batch。

下面我们分析分析这个过程是咋实现的。首先,DataLoader(object)源码中有下面这么一段代码:

                。。。。。。
if sampler is None:  # give default samplers
    if self.dataset_kind == _DatasetKind.Iterable:
        # See NOTE [ Custom Samplers and IterableDataset ]
        sampler = _InfiniteConstantSampler()
    else:  # map-style
        if shuffle:
            sampler = RandomSampler(dataset)
        else:
            sampler = SequentialSampler(dataset)
                。。。。。。

按照上面的设置,sampler默认是None,我们没有定义要打乱数据(即shuffle为False),则接下来会调用

sampler = SequentialSampler(dataset)

再来看看这个方法是怎么实现的:

class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)

主要看__iter__部分,明显的,假设数据集共有n个数据,这是一个返回的sampler就是数据集长度[0,1,2,......,n-1]序号的迭代器。关于怎么迭代,我们回到DataLoader(object)源码中继续往先看,会发现这么几条代码:

if batch_size is not None and batch_sampler is None:
    # auto_collation without custom batch_sampler
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)

首先说一下,这个代码就是从上一步的迭代器sampler中取出batch_size个序号,batch_size之前我们设置的是4,所以就是取出4个序号(索引),用以后面从traindata中取出batch_size个数据,来看一下BatchSampler方法的迭代方式的实现,注意这里的yield机制:

class BatchSampler(Sampler):
        。。。。。。
    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch
        。。。。。。

所以,到这里我们一个batch_size的数据的索引就已经有了,后面就是调用多线程或单线程机制来取出对应的数据traindata[i]了。回到DataLoader(object)源码中,在往下看,就是下面这段代码了:

def __iter__(self):
    if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)
    else:
        return _MultiProcessingDataLoaderIter(self)

这段代码就是DataLoader的迭代器的实现方式了,具体的单多线程实现就不详细展开了。此时我们已经完成了获取本次迭代所需要的数据的索引值,接下来即使按照索引在traindata中找到相应的数据并一起返回这个mini_batch了。比如我们可以这样获取数据并用于训练:

for iteration, batch in enumerate(gen_train):
    if iteration >= epoch_size:  # 判断是否到达一个epoch的迭代次数(len(traindata)/batchsize)
        break
    x_datas, y_labels= batch[0], batch[1]  # 获取batch中的数据和标签,用于训练
                ......

我们就可以使用这批数据进行一次网络的训练了,这么周而复始,直至达到我们设置的epoch。

3、总结

一般情况下,Dataset类与Dataloader类是配合着使用的,Dataset负责整理数据,Dataloader负责从Dataset整理好的数据中按照一定的规则取出batch_size个数据来供网络训练使用。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch:Dataset()与Dataloader()的使用详解 的相关文章

随机推荐

  • springmvc实现文件上传与下载【单张及多张图片】

    一 springmvc实现文件上传的步骤 1 实现上传单张图片 1 导入pom 坐标 span class token comment lt 文件上传 gt span span class token tag span class toke
  • SpringBoot 搭建的个人博客

    介绍 blog是基于SpringBoot 搭建的个人博客 xff0c 响应式 前端技术 xff1a html css js jq bootstrap 后台技术 xff1a springboot thymeleaf mybatis mysql
  • SpringCloud(一)微服务概述

    文章目录 微服务概述什么是微服务微服务与微服务架构微服务的优缺点优点缺点 微服务的技术栈为什么选SpringCloud作为微服务架构选型依据当前各大IT公司的微服务架构各微服务的框架对比 SpringCloud入门概述SpringCloud
  • SpringCloud(二)入门案例之支付模块与订单模块的调用

    SpringCloud xff08 一 xff09 微服务概述 xff1a https blog csdn net weixin 45606067 article details 108481733 构建SpringCloud工程 概述 x
  • SpringCloud(三)Eureka服务注册中心

    文章目录 1 Eureka基础知识什么是服务治理什么是服务注册Eureka两大组件 2 Eureka介绍及原理理解介绍原理 3 单机版Eureka 构建步骤4 集群版Eureka 构建步骤Eureka集群原理说明EurekaServer集群
  • SpringCloud(四)zookeeper介绍及原理

    SpringCloud xff08 四 xff09 zookeeper介绍及原理 xff1a https blog csdn net weixin 45606067 article details 108499344 Zookeeper服务
  • docker 的安装 - 常用命令 - 应用部署

    文章目录 1 Docker简介什么是虚拟化什么是Docker容器与虚拟化比较Docker 组件1 Docker服务器与客户端2 Docker镜像与容器3 Register xff08 注册中心 xff09 2 Docker安装与启动安装Do
  • SpringCloud(六)Ribbon负载均衡服务调用

    Ribbon负载均衡 概述 是什么 Spring Cloud Ribbon是基于Netflix Ribbon实现的一套客户端 负载均衡工具 简单的说 xff0c Ribbon是Netflix发布的开源项目 xff0c 主要功能是提供客户端的
  • Python:map()函数使用详解

    1 函数定义 xff1a map function iterable 2 作用 xff1a 该函数通过接收一个函数function作为处理函数 xff0c 然后接收一个参数序列iterable xff0c 并使用处理函数对序列中的每个元素逐
  • SpringCloud(五)Consul服务注册与发现

    SpringCloud xff08 四 xff09 zookeeper介绍及原理 xff1a https blog csdn net weixin 45606067 article details 108538357 Consul简介 是什
  • SpringCloud(七)OpenFeign负载均衡服务调用

    1 概述 1 OpenFeign是什么 官网解释 xff1a https cloud spring io spring cloud static Hoxton SR1 reference htmlsingle spring cloud op
  • Zookeeper概述 | 安装部署(Windows和Linux)

    Zookeeper 一 Zokeeper 门 1 概述 Zookeeper是一个开源的分布式的 xff0c 为分布式应用提供协调服务的Apache项目 ZooKeeper is a centralized service for maint
  • Zookeeper内部原理

    Zookeeper概述 安装部署 xff08 Windows和Linux xff09 xff1a https blog csdn net weixin 45606067 article details 108619378 1 选举机制 面试
  • jsp和servlet的区别

    基本介绍 Servlet xff1a Servlet 是一种服务器端的Java应用程序 xff0c 具有独立于平台和协议的特性 xff0c 可以生成动态的Web页面 它担当客户请求 xff08 Web浏览器或其他HTTP客户程序 xff09
  • Session学习笔记

    1 session 简介 session 是我们 jsp 九大隐含对象的一个对象 session 称作域对象 xff0c 他的作用是保存一些信息 xff0c 而 session 这个域对象是一次会话期间使用同一个对象 所以这个对象可以用来保
  • session和cookie 区别【面试】

    说说Cookie和Session的区别 xff1f 1 存取方式的不同 xff08 Cookie只能保存ASCII xff0c Session可以存任意数据类型 xff09 Cookie中只能保管ASCII字符串 xff0c 假如需求存取U
  • JSP 九大内置对象,四大域对象

    JSP的九大内置对象 内置对象名 类型 request HttpServletRequest response HttpServletResponse session HttpSession application ServletConte
  • SpringCloud(八)Hystrix断路器

    文章目录 1 概述分布式系统面临的问题是什么能干嘛官网资料Hystrix官宣 xff0c 停更进维 2 Hystrix重要概念3 hystrix案例构建项目高并发测试故障现象和导致原因上诉结论如何解决 xff1f 解决的要求服务降级服务熔断
  • 谷粒学院(一)项目介绍

    一 项目背景 在线教育顾名思义 xff0c 是以网络为介质的教学方式 xff0c 通过网络 xff0c 学员与教师即使相隔万里也可以开展教学活动 xff0c 此外 xff0c 借助网络课件 xff0c 学员还可以随时随地进行学习 xff0c
  • PyTorch:Dataset()与Dataloader()的使用详解

    目录 1 Dataset类的使用 2 Dataloader类的使用 3 总结 Dataset类与Dataloader类是PyTorch官方封装的用于在数据集中提取一个batch的训练用数据的接口 xff0c 其实我们也可以自定义获取每个ba