PyTorch中的Dataset、Dataloader和_DataloaderIter

2023-05-16

Dataset

Pytorch中数据集被抽象为一个抽象类torch.utils.data.Dataset,所有的数据集都应该继承这个类,并override以下两项:

  • __len__:代表样本数量。len(obj)等价于obj.__len__()
  • __getitem__:返回一条数据或一个样本。obj[index]等价于obj.__getitem__。建议将节奏的图片等高负载的操作放到这里,因为多进程时会并行调用这个函数,这样做可以加速。

dataset中应尽量只包含只读对象,避免修改任何可变对象。因为如果使用多进程,可变对象要加锁,但后面讲到的dataloader的设计使其难以加锁。如下面例子中的self.num可能在多进程下出问题:

class BadDataset(Dataset):
	def __init__(self):
		self.datas = range(100)
		self.num = 0 # read data times
	def __getitem__(self, index):
		self.num += 1
		return self.datas[index]

Dataloader

官方documentation

Dataset负责表示数据集,它可以每次使用__getitem__返回一个样本。而torch.utils.data.Dataloader提供了对batch的处理,如shuffle等。Dataset被封装在了Dataloader中。

Dataloader的构造函数如下:

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
	batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, 
	pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

部分参数解释如下:

  • num_workers:使用的子进程数,0为不使用多进程。
    • worker_init_fn: 默认为None,如果不是None,这个函数将被每个子进程以子进程id([0, num_workers - 1]之间的数)调用
  • sample:采样策略,若这个参数有定义,则shuffle必须为False
  • pin_memory:是否将tensor数据复制到CUDA pinned memory中,pin memory中的数据转到GPU中会快一些
  • drop_last:当dataset中的数据数量不能整除batch size时,是否把最后$\text{len(dataset)} \bmod \text{batch_size} $个数据丢掉
  • collate_fn:把一组samples打包成一个mini-batch的函数。可以自定义这个函数以处理损坏数据的情况(先在__getitem__函数中将这样的数据返回None,然后再在collate_fn中处理,如丢掉损坏数据or再从数据集里随机挑一张),但最好还是确保dataset里所有数据都能用。

另外,Dataloader是个iterable,可以进行相关迭代操作。

DataLoaderIter

DatasetDataloaderDataLoaderIter是层层封装的关系,最终在内部使用DataLoaderIter进行迭代。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch中的Dataset、Dataloader和_DataloaderIter 的相关文章

随机推荐

  • YOLOv5识别目标的实时坐标打印

    引言 这个功能看似鸡肋 xff0c 但对于无人机目标识别与追踪有重要意义 xff0c 通过目标在摄像头视野的坐标位置 xff0c 可以推算出无人机相对与目标的位置 xff0c 从而对无人机进行位置矫正 因此 xff0c 添加代码打印坐标并不
  • 六、WebRTC中ICE的实现

    一 Candidate种类 amp 优先级 高到底 xff1a host srflx prflx relay 同一局域网内通过host类型的Candidate在内网建立连接 非同一局域网 xff0c 隔断从STUN TURN服务器中收集sr
  • 七、WebRTC中的SDP

    一 SDP标准规范 格式 xff1a lt type gt 61 lt value gt SDP 会话层 媒体层 媒体音频 媒体视频 二 WebRTC中的SDP的整体结构 1 媒体信息 m 61 行中描述媒体类型 传输类型 Playload
  • linux 信号量sem

    一 信号量 信号量如同一盏红绿信号灯 xff0c 用于临界资源 xff08 如公路 人行道 xff09 的管理 信号量是一种特殊的变量 xff0c 访问具有原子性 P等待 xff1a 信号量的值为0时 xff0c 不能减 xff0c 则进行
  • 1-4 实验3 串口通信

    串口通信 1 实验内容 xff1a PC端串口调试助手向板子发送数据 xff0c 板子接受到数据后 xff0c 再把数据发送回给PC端串口调试助手 2 串口发送接受数据的基本步骤 xff1a 初始化串口 xff08 设置波特率 中断等 xf
  • 1-6 实验5 无线温度检测实验

    无线温度检测实验 1 实验内容 xff1a 协调器建立ZigBee无线网络 xff0c 终端节点自动加入网络 xff0c 然后终端节点周期性地采集温度并将数据发送到协调器 协调器接受数据并通过串口把接受到的数据传给PC端的串口调试助手 2
  • 1-11 实验9 网络管理实验1 获取自身的和父节点网络地址、MAC地址

    p p p style color rgb 51 51 51 font family Arial font size 14px line height 26px 获取自身的和父节点网络地址 MAC地址 p p style color rgb
  • 1-14 实验11 获取网络拓扑

    获取网络拓扑 1 实验内容 xff1a PC端串口调试助手向协调器发送命名 topology 协调器接受到命令后 xff0c 将网络拓扑信息发送到PC机串口调试助手上 2 知识点 xff1a 在1 11 实验9 网络管理实验1 获取自身和父
  • S 串口编程 详解3 串口的初始化、打开/关闭

    串口编程 详解3 串口的初始化 程序打开串口 xff0c 采用两种方法 xff1a 1 程序启动 xff0c 调用OnInitDialog 函数 xff0c 打开串口 xff0c 缺省串口号为COM1 xff0c 如果COM1不存在或被占用
  • 求关键路径(包含邻接表的建立、拓扑排序)

    include lt stdio h gt include lt stdlib h gt typedef struct node int adjvex 邻接点域 int info 边上的信息 struct node next 指向下一个邻接
  • FPGA串口回环实验

    本文将从个人理解的角度 xff0c 解释FPGA串口通信的原理 xff0c 并进行实战演示 1 写在前面的话 串口通信是初学FPGA必过的一道坎 xff0c 如果能够在不参考任何资料的情况下自己手搓一套串口回环的代码 xff0c Verio
  • Debug Assertion Failed!解决方法详解

    1 野指针 2 内存泄露 解决方法 1 看一看你的程序里是不是有 ASSERT xff08 xff09 或 VERIFY xff08 xff09 语句 这两个宏是用来测试它的参数是否为真的 出现你说的 xff0c 这说明你的指针或表达试有问
  • 用tftp的方式在u_boot下 烧写uImage内核

    用 u boot 进行下载 uImage 一种 kernel 镜像文件 首先 把编译好的 uImage 文件放在 tftpboot 目录下 用网线把开发板和电脑连上 但PC上的网卡显示是没连接的 xff0c 这一点是没有关系的 xff0c
  • 利用NFS服务挂载NFS根文件系统

    嵌入式Linux根文件系统 xff0c 简单地说 xff0c 根文件系统就是一种目录结构 注意根文件系统和普通的文件系统的区别 常见的Linux根文件系统有 xff1a xff08 1 xff09 NFS xff08 网络根文件系统 xff
  • 数据校验之Checksum算法

    校验和 xff08 Checksum xff09 是网络协议使用的数据错误检测方法 xff0c 并且被认为比LRC xff08 纵向冗余校验 xff0c Longitudinal Redundancy Check xff0c LRC xff
  • 位序转字符串的一种高效方法

    include lt stdio h gt include lt stdlib h gt include lt malloc h gt include lt string h gt include lt arpa inet h gt def
  • OpenSIPS实战(一):OpenSIPS使用简介

    1 OpenSIPS是什么 OpenSIPS xff08 Open SIP Server xff09 是一个成熟的开源SIP服务器实现 可以作为SIP代理 路由器 但OpenSIPS不仅仅是一个SIP代理 路由器 xff0c 因为它包含了应
  • Floyd判圈算法(龟兔赛跑算法, Floyd's cycle detection)及其证明

    问题 xff1a 如何检测一个链表是否有环 xff08 循环节 xff09 xff0c 如果有 xff0c 那么如何确定环的起点以及环的长度 空间要求 xff1a 不能存储所经过的的每一个点 举例 xff1a x 0 61 1 x 0 61
  • Ubuntu配置GPU版本pytorch环境(含NVIDIA驱动+Cuda+Cudnn)

    本文更新于2018年8月底 概述 步骤如下 xff1a 1 安装Ubuntu 2 安装NVIDIA 显卡驱动 2 安装NVIDIA Cuda 3 安装NVIDIA CuDNN 4 安装GPU版本的PyTorch 安装Ubuntu 系统版本选
  • PyTorch中的Dataset、Dataloader和_DataloaderIter

    Dataset Pytorch中数据集被抽象为一个抽象类torch utils data Dataset xff0c 所有的数据集都应该继承这个类 xff0c 并override以下两项 xff1a len xff1a 代表样本数量 len