pytorch源码分析之torch.utils.data.Dataset类和torch.utils.data.DataLoader类

2023-11-10

写在之前
介绍

Pytorch深度学习框架优势之一是python优先,源代码由python代码层和C语言代码层组成,一般只需要理解python代码层就可以深入理解pytorch框架的计算原理。所以学习pytorch源码需要熟练掌握python语言的各种使用技巧。

在处理任何机器学习问题之前都需要数据读取,并进行预处理。Pytorch提供了许多方法使得数据读取和预处理变得很容易。

  • torch.utils.data.Dataset是代表自定义数据集方法的抽象类,你可以自己定义你的数据类继承这个抽象类,非常简单,只需要定义__len____getitem__这两个方法就可以。
  • 通过继承torch.utils.data.Dataset的这个抽象类,我们可以定义好我们需要的数据类。当我们通过迭代的方式来取得每一个数据,但是这样很难实现取batch,shuffle或者多线程读取数据,所以pytorch还提供了一个简单的方法来做这件事情,通过torch.utils.data.DataLoader类来定义一个新的迭代器,用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入。
    总之,通过torch.utils.data.Datasettorch.utils.data.DataLoader这两个类,使数据的读取变得非常简单,快捷。
这两个抽象类中用到的python知识点

能够熟练的使用python语言的技巧,是理解pytorch源码的关键。在torch.utils.data.Datasettorch.utils.data.DataLoader这两个类中会用到python抽象类的魔法方法,包括__len__(self)__getitem__(self)__iter__(self)

  • __len__(self) 定义当被len()函数调用时的行为(返回容器中元素的个数)
  • __getitem__(self)定义获取容器中指定元素的行为,相当于self[key],即允许类对象可以有索引操作。
  • __iter__(self)定义当迭代容器中的元素的行为

下面通过介绍python定制容器的方式来介绍__len__(self)__getitem__(self)两种方法。
在python中,像序列类型(如列表,元组和字符串)或映射类型(如字典)都属于容器类型。讲定制容器,那就必须要知道,定制容器有关的一些协议:

  • 如果你希望定制的容器是不可变的话,你只需要定义__len__()__getitem__这两个魔法方法。
  • 如果你希望定制的容器是可变的话,除了__len__()__getitem__这两个魔法方法,还需要定义__setitem__()__delitem__()两个方法。

小案例:编写一个不可变的自定义列表,要求记录列表中每个元素被访问的次数。

class CountList:
	def __init__(self, *args):
		self.values = [x for x in args]
		self.count = {
   }.fromkeys(range(len(self.values)),0)
		# 这里使用列表的下标作为字典的键,注意不能用元素作为字典的键
		# 因为列表的不同下标可能有值一样的元素,但字典不能有两个相同的键
	def __len__(self):
		return len(self.values)
	def __getitem__(self, key):
		self.count[key] += 1
		return self.values[key]
c1 = CountList(1,3,5,7,9)
c2 = CountLIst(2,4,6,8,10)

# 调用
c1[1]  ## 3
c2[1]  ## 4
c1[1] + c2[1] 	## 7
c1.count  ## {0:0,1:2,2:0,3:0,4:0}
c2.count  ## {0:0,1:2,2:0,3:0,4:0}	

接下来讲解__iter__(self)方法。这个魔法方法是在python构造迭代器的时候需要定义的。迭代的意思类似于循环,每一次重复的过程被称为一次迭代的过程,而每一次迭代得到的结果会被用来作为下一次迭代的初始值。提供迭代方法的容器称为迭代器,通常接触的迭代器有序列(列表、元组和字符串)还有字典也是迭代器,都支持迭代操作。那么实现迭代器的魔法方法有两个:

  • __iter__()
  • __next__()
    一个容器如果是迭代器,那就必须实现__iter__()魔法方法,这个方法实际上是返回迭代器本身。接下来重点要实现的是__next__()魔法方法,因为它决定了迭代的规则。举个简单的例子:
class Fibs:
	def __init__(self, n=20):
		self.a = 0
		self.b = 1
		self.n = n
	def __iter__(self):
		return self
	def __next__(self):
		self.a, self.b = self.b, self.a + self.b
		if self.a > self.n:
			raise StopIteration
		return self.a

## 调用
fibs = Fibs()
for each in fibs:
	print(each)
## 输出
1
1
2
3
5
8
13
torch.utils.data.Dataset类

源码:

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

一个用来表示数据集的抽象类,其他所有的数据集都应该是这个类的子类,并且需要重写__len____getitem__

torch.utils.data.DataLoader类

DataLoader类源码如下。先看看__init__中的几个重要的输入:1、dataset,这个就是PyTorch已有的数据读取接口(比如torchvision.datasets.ImageFolder)或者自定义的数据接口的输出,该输出要么是torch.utils.data.Dataset类的对象,要么是继承自torch.utils.data.Dataset类的自定义类的对象。2、batch_size,根据具体情况设置即可。3、shuffle,一般在训练数据中会采用。4、collate_fn,是用来处理不同情况下的输入dataset的封装,一般采用默认即可,除非你自定义的数据读取输出非常少见。5、batch_sampler,从注释可以看出,其和batch_size、shuffle等参数是互斥的,一般采用默认。6、sampler,从代码可以看出,其和shuffle是互斥的,一般默认即可。7、num_workers,从注释可以看出这个参数必须大于等于0,0的话表示数据导入在主进程中进行,其他大于0的数表示通过多个进程来导入数据,可以加快数据导入速度。8、pin_memory,注释写得很清楚了: pin_memory (bool, optional): If True, the data loader will copy tensors into CUDA pinned memory before returning them. 也就是一个数据拷贝的问题。9、timeout,是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。
在__init__中,RandomSampler类表示随机采样且不重复,所以起到的就是shuffle的作用。BatchSampler类则是把batch size个RandomSampler类对象封装成一个,这样就实现了随机选取一个batch的目的。这两个采样类都是定义在sampler.py脚本中,地址:https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py。以上这些都是初始化的时候进行的。当代码运行到要从torch.utils.data.DataLoader类生成的对象中取数据的时候,比如:
train_data=torch.utils.data.DataLoader(…)
for i, (input, target) in enumerate(train_data):

就会调用DataLoader类的__iter__方法,__iter__方法就一行代码:return DataLoaderIter(self),输入正是DataLoader类的属性。因此当调用__iter__方法的时候就牵扯到另外一个类:DataLoaderIter,接下来介绍。

class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

  
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch源码分析之torch.utils.data.Dataset类和torch.utils.data.DataLoader类 的相关文章

随机推荐

  • 【区块链】深度长文:2018新风口,区块链3.0时代即将来临?

    徐小平说 区块链将掀起一场革命 1月9日 徐小平在真格投资组合群里分享了一段关于拥抱区块链时代的内容 并表示不能外传 被泄露的微信截图 岂料 很快去传了出去 他表示 这本是其与被投公司间的 低调 内容分享 现在被人擅自传出去 也没办法 最后
  • Express基本认识

    express是一个基于nodejs 且快速 开放的一个web开发框架 安装命令 yarn add express 查看express的所有版本 npm view express versions 搭建基本的express程序 const
  • ./configure --prefix=

    一直用这个选项prefix 但不知道 啥意思 转载自 inux安装软件采用源码安装灵活自由 适用于不同的平台 维护也十分方便 源码的安装一般由3个步骤组成 配置 configure 编译 make 安装 make install 具体的安装
  • centos7设置开机为命令行启动

    图形界面默认安装之后 每次启动都是图形界面启动 图形界面需要占用系统大量的内存和CPU资源 对于个人电脑和服务器 将Centos 默认启动改为文本方式 会显著提高运行效率 方法一 不修改默认启动方式 root模式下 init 3 gt 切换
  • sqlserver查看执行计划

    方式一 通过Microsoft sql server management studio工具栏中的 显示估计的执行计划 按钮 选中SQL 然后点击该按钮 SQL就会给我们选中SQL的图形执行计划 方式二 set showplan all o
  • 【无监督学习】0、有监督学习、无监督学习、半监督学习

    文章目录 一 有监督学习 二 半监督学习 三 无监督学习 3 1 对比式学习 一 有监督学习 有监督学习最大的特点就是数据集是带标签的 如有监督分类任务 就是给每张图都分配一个真实标签 表示这张图是 dog cat 或者是 bird 而标签
  • 【iOS】内存管理

    文章目录 前言 理解引用计数 引用计数原理 属性存取方法中的内存管理 自动释放池 保留环 以ARC简化引用计数 使用ARC时必须遵守的命名规则 变量的内存管理语义 ARC如何清理实例变量 覆写内存管理的方法 在dealloc方法中只释放应用
  • Lua 15分钟快速上手(上)

    本系列相关文章 Flutter 热更新及动态UI生成 Lua 15分钟快速上手 上 Lua 15分钟快速上手 下 Lua与C语言的互相调用 LuaDardo中Dart与Lua的相互调用 在之前的博客 Flutter 热更新及动态UI生成 一
  • 【python环境搭建】conda 安装过程中无法激活 python 虚拟环境问题

    目录 1 概要 2 解决办法 1 概要 最近重新学习python 需要搭建conda 环境 遇到一个懵逼的问题 C Users 67656 gt conda activate Date CommandNotFoundError Your s
  • Unity API Camera摄像机的使用

    Camera main 返回主摄像机的Camera组件 第一个启用的标签为 MainCamera 的摄像机 只读 场景中的主要摄像机 如果场景中没有这样的摄像机 则返回null 此属性在内部使用FindGameObjectsWithTag
  • 合宙Air724UG LuatOS-Air LVGL API控件--按钮 (Button)

    按钮 Button 按钮控件 这个就不用多说了 界面的基础控件之一 示例代码 按键回调函数 event handler function obj event if event lvgl EVENT CLICKED then print Cl
  • 执行Shell脚本的4种方法

    假设我们编写好的shell脚本的文件名为hello sh 文件位置在 root bin目录中并已有执行权限 添加权限的方法 chmod x hello sh 1 方法一 切换到shell脚本所在的目录 此时 称为工作目录 执行shell脚本
  • anaconda代码

    因为老是不记得代码 要找来找去的 索性自己写一下怕忘记 windos conda info envs 查看本机所有的虚拟环境 conda remove n 你自己的环境的名字 all 删除虚拟环境 conda create n 自己想取的名
  • linux网络服务network没了,Linux网络服务(network service)管理

    Linux操作系统中重新启动网络的方法 网页链接 https ywnz com linux 4463 html 1 网络管理员服务 这是使用命令行重新启动网络的最简单方法 它等同于图形化方式 重新启动Network Manager服务 su
  • Unity2D修改Sprite颜色和透明度

    Unity2D修改Sprite颜色和透明度 简单注意原理在前边 后面是实现方法 首先创建一个Sprite 最好选择纯白色的Sprite 选择的图片需要是白色的才会在修改颜色后有明显的显示 因为颜色修改后 它的最终显示是本来的图片的颜色与修改
  • QT 信号和槽

    信号和槽是一种高级接口 应用于对象之间的通信 它是 QT 的核心特性 要正确的处理信号和槽 必须借助一个称为 moc Meta Object Compiler 的 QT 工具 该工具是一个 C 预处理程序 它为高层次的事件处理自动生成所需要
  • 解决bug“ImportError: numpy.core.multiarray failed to import”

    解决bug ImportError numpy core multiarray failed to import 在这之前升级scikit image从老版本0 13 0到0 17 2 但运行pycharm工程出现如下bug from fi
  • msys2 安装 mingw64

    https blog csdn net zhuwade article details 121944279
  • vue+element 图片右上角添加删除小×、按钮预览图片

    思维方法 这个问题实际就是一个思维方式的问题 我最开始思考的就很复杂 后来我同事给出的解决方法就好 方法是 在判断有图片的时候 在图片的右上角加上一个小 的图片 在这个图片上加方法 点击就把图片清空 所以有的时候 一件事情不能想的太复杂 代
  • pytorch源码分析之torch.utils.data.Dataset类和torch.utils.data.DataLoader类

    写在之前 介绍 Pytorch深度学习框架优势之一是python优先 源代码由python代码层和C语言代码层组成 一般只需要理解python代码层就可以深入理解pytorch框架的计算原理 所以学习pytorch源码需要熟练掌握pytho