Pytorch数据读取与预处理实现与探索

2023-11-02

  在炼丹时,数据的读取与预处理是关键一步。不同的模型所需要的数据以及预处理方式各不相同,如果每个轮子都我们自己写的话,是很浪费时间和精力的。Pytorch帮我们实现了方便的数据读取与预处理方法,下面记录两个DEMO,便于加快以后的代码效率。

  根据数据是否一次性读取完,将DEMO分为:

  1、串行式读取。也就是一次性读取完所有需要的数据到内存,模型训练时不会再访问外存。通常用在内存足够的情况下使用,速度更快。

  2、并行式读取。也就是边训练边读取数据。通常用在内存不够的情况下使用,会占用计算资源,如果分配的好的话,几乎不损失速度。

  Pytorch官方的数据提取方式尽管方便编码,但由于它提取数据方式比较死板,会浪费资源,下面对其进行分析。

串行式读取

DEMO代码

import torch 
from torch.utils.data import Dataset,DataLoader 
   
class MyDataSet(Dataset):# ————1————
  def __init__(self):    
    self.data = torch.tensor(range(10)).reshape([5,2])
    self.label = torch.tensor(range(5))

  def __getitem__(self, index):   
    return self.data[index], self.label[index]

  def __len__(self):    
    return len(self.data)
  
my_data_set = MyDataSet()# ————2————
my_data_loader = DataLoader(
  dataset=my_data_set,   # ————3————
  batch_size=2,          # ————4————
  shuffle=True,          # ————5————
  sampler=None,          # ————6————
  batch_sampler=None,    # ————7———— 
  num_workers=0 ,        # ————8———— 
  collate_fn=None,       # ————9———— 
  pin_memory=True,       # ————10———— 
  drop_last=True         # ————11————
)

for i in my_data_loader: # ————12————
  print(i)

  注释处解释如下:

  1、重写数据集类,用于保存数据。除了 __init__() 外,必须实现 __getitem__() 和 __len__() 两个方法。前一个方法用于输出索引对应的数据。后一个方法用于获取数据集的长度。

  2~5、 2准备好数据集后,传入DataLoader来迭代生成数据。前三个参数分别是传入的数据集对象、每次获取的批量大小、是否打乱数据集输出。

  6、采样器,如果定义这个,shuffle只能设置为False。所谓采样器就是用于生成数据索引的可迭代对象,比如列表。因此,定义了采样器,采样都按它来,shuffle再打乱就没意义了。

  7、批量采样器,如果定义这个,batch_size、shuffle、sampler、drop_last都不能定义。实际上,如果没有特殊的数据生成顺序的要求,采样器并没有必要定义。torch.utils.data 中的各种 Sampler 就是采样器类,如果需要,可以使用它们来定义。

  8、用于生成数据的子进程数。默认为0,不并行。

  9、拼接多个样本的方法,默认是将每个batch的数据在第一维上进行拼接。这样可能说不清楚,并且由于这里可以探究一下获取数据的速度,后面再详细说明。

  10、是否使用锁页内存。用的话会更快,内存不充足最好别用。

  11、是否把最后小于batch的数据丢掉。

  12、迭代获取数据并输出。

速度探索

  首先看一下DEMO的输出:

  输出了两个batch的数据,每组数据中data和label都正确排列,符合我们的预期。那么DataLoader是怎么把数据整合起来的呢?首先,我们把collate_fn定义为直接映射(不用它默认的方法),来查看看每次DataLoader从MyDataSet中读取了什么,将上面部分代码修改如下:

my_data_loader = DataLoader(
  dataset=my_data_set,    
  batch_size=2,           
  shuffle=True,           
  sampler=None,         
  batch_sampler=None,    
  num_workers=0 ,        
  collate_fn=lambda x:x,  #修改处
  pin_memory=True,       
  drop_last=True         
)

  结果如下:

  输出还是两个batch,然而每个batch中,单个的data和label是在一个list中的。似乎可以看出,DataLoader是一个一个读取MyDataSet中的数据的,然后再进行相应数据的拼接。为了验证这点,代码修改如下:

import torch 
from torch.utils.data import Dataset,DataLoader 
   
class MyDataSet(Dataset): 
  def __init__(self):    
    self.data = torch.tensor(range(10)).reshape([5,2])
    self.label = torch.tensor(range(5))

  def __getitem__(self, index):   
    print(index)          #修改处2
    return self.data[index], self.label[index]

  def __len__(self):    
    return len(self.data)
  
my_data_set = MyDataSet() 
my_data_loader = DataLoader(
  dataset=my_data_set,    
  batch_size=2,           
  shuffle=True,           
  sampler=None,         
  batch_sampler=None,    
  num_workers=0 ,        
  collate_fn=lambda x:x,  #修改处1
  pin_memory=True,       
  drop_last=True         
)

for i in my_data_loader:  
  print(i) 

  输出如下:

  验证了前面的猜想,的确是一个一个读取的。如果数据集定义的不是格式化的数据,那还好,但是我这里定义的是tensor,是可以直接通过列表来索引对应的tensor的。因此,DataLoader的操作比直接索引多了拼接这一步,肯定是会慢很多的。一两次的读取还好,但在训练中,大量的读取累加起来,就会浪费很多时间了。

  自定义一个DataLoader可以证明这一点,代码如下:

import torch 
from torch.utils.data import Dataset,DataLoader 
from time  import time
   
class MyDataSet(Dataset): 
  def __init__(self):    
    self.data = torch.tensor(range(100000)).reshape([50000,2])
    self.label = torch.tensor(range(50000))

  def __getitem__(self, index):    
    return self.data[index], self.label[index]

  def __len__(self):    
    return len(self.data)

# 自定义DataLoader
class MyDataLoader():
  def __init__(self, dataset,batch_size):
    self.dataset = dataset
    self.batch_size = batch_size
  def __iter__(self):
    self.now = 0
    self.shuffle_i = np.array(range(self.dataset.__len__())) 
    np.random.shuffle(self.shuffle_i)
    return self
 
  def __next__(self): 
    self.now += self.batch_size
    if self.now <= len(self.shuffle_i):
      indexes = self.shuffle_i[self.now-self.batch_size:self.now]
      return self.dataset.__getitem__(indexes)
    else:
      raise StopIteration

# 使用官方DataLoader
my_data_set = MyDataSet() 
my_data_loader = DataLoader(
  dataset=my_data_set,    
  batch_size=256,           
  shuffle=True,           
  sampler=None,         
  batch_sampler=None,    
  num_workers=0 ,        
  collate_fn=None,  
  pin_memory=True,       
  drop_last=True         
)

start_t = time()
for t in range(10):
  for i in my_data_loader:  
    pass
print("官方:", time() - start_t)
 
 
#自定义DataLoader
my_data_set = MyDataSet() 
my_data_loader = MyDataLoader(my_data_set,256)

start_t = time()
for t in range(10):
  for i in my_data_loader:  
    pass
print("自定义:", time() - start_t)

  运行结果如下:

  以上使用batch大小为256,仅各读取10 epoch的数据,都有30多倍的时间上的差距,更大的batch差距会更明显。另外,这里用于测试的每个数据只有两个浮点数,如果是图像,所需的时间可能会增加几百倍。因此,如果数据量和batch都比较大,并且数据是格式化的,最好自己写数据生成器。

并行式读取

DEMO代码

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader 
from torchvision import transforms 
from torchvision.datasets import ImageFolder  
  
path = r'E:\DataSets\ImageNet\ILSVRC2012_img_train\10-19\128x128'
my_data_set = ImageFolder(            #————1————
  root = path,                        #————2————
  transform = transforms.Compose([    #————3————
    transforms.ToTensor(),
    transforms.CenterCrop(64)
  ]),
  loader = plt.imread                 #————4————
)
my_data_loader = DataLoader(
  dataset=my_data_set,      
  batch_size=128,             
  shuffle=True,             
  sampler=None,             
  batch_sampler=None,        
  num_workers=0,            
  collate_fn=None,           
  pin_memory=True,           
  drop_last=True 
)           

for i in my_data_loader: 
  print(i)

  注释处解释如下:

  1/2、ImageFolder类继承自DataSet类,因此可以按索引读取图像。路径必须包含文件夹,ImageFolder会给每个文件夹中的图像添加索引,并且每张图像会给予其所在文件夹的标签。举个例子,代码中my_data_set[0] 输出的是图像对象和它对应的标签组成的列表。

  3、图像到格式化数据的转换组合。更多的转换方法可以看 transform 模块。

  4、图像法的读取方式,默认是PIL.Image.open(),但我发现plt.imread()更快一些。

  由于是边训练边读取,transform会占用很多时间,因此可以先将图像转换为需要的形式存入外存再读取,从而避免重复操作。

  其中transform.ToTensor()会把正常读取的图像转换为torch.tensor,并且像素值会映射至$[0,1]$。由于plt.imread()读取png图像时,像素值在$[0,1]$,而读取jpg图像时,像素值却在$[0,255]$,因此使用transform.ToTensor()能将图像像素区间统一化。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch数据读取与预处理实现与探索 的相关文章

  • Cucumber DataTable 错误 - io.cucumber.datatable.UndefinedDataTableTypeException:无法将 DataTable 转换为 cucumber.api.DataTable

    尝试使用 cucumber selenium java intelliJ 运行场景 但在其中一个步骤中出现有关 DataTable 的错误 在我开始使用测试运行程序并更改周围的一些内容之前 数据表工作正常并正确转换该步骤的参数 但我就是无法
  • 是否有最新的 Facebook Java SDK? [关闭]

    Closed 这个问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 好像没找到最近更新的 如果没有 是否有一个好的 Java 库来执行与 Facebook 的 API 交
  • Python unicode 字符代码?

    有没有办法将 Unicode 字符 插入 Python 3 中的字符串 例如 gt gt gt import unicode gt gt gt string This is a full block s unicode charcode U
  • 字典的嵌套列表

    我正在尝试创建dict通过嵌套list groups Group1 A B Group2 C D L y x 0 for y in x if y x 0 for x in groups d k v for d in L for k v in
  • 如何使用文本和?

    我一直在关注this https github com tensorflow models tree master textsum使用 textsum 的链接 我已经使用提供的命令训练了模型 但我在 textsum log root 目录中
  • 改变 Java 中凯撒移位的方向

    用户可以通过选择 1 向左或 2 向右移动字母来选择向左或向右移动 左边工作正常 右边不行 现在它显示了完全相同的循环 但我已经改变了所有 and 以不同的方式进行标记 最终我总是得到奇怪的字符 如何让程序将字符向相反方向移动 如果用户输入
  • 配置jmxremote时无法正常停止tomcat

    我添加了一个jmxremotecatalina bat中的配置 set JAVA OPTS Dcom sun management jmxremote port 9004 Dcom sun management jmxremote ssl
  • Jetty Plugin 9启动不喜欢icu4j-2.6.1.jar

    我对 mortbay 的 Maven jetty 插件 6 有相同的配置
  • JAXB 编组器无参数默认构造函数

    我想从 java 库中编组一个 java 对象 当使用 JAXB marschaller 编组 java 对象时 我遇到了一个问题 A 类没有无参数默认构造函数 我使用Java Decompiler来检查类的实现 它是这样的 public
  • 如何检查日期字符串的有效性?

    在我的项目中 我需要检查日期字符串是否计算为正确的日期对象 我决定允许 yyyy MM dd 和日期格式 年 月 日 和 年 月 日 小时 分钟 我如何检查它们是否有效 我的代码为 1980 01 01 和一些奇怪的日期 如 3837 05
  • 在 pip.conf 中指定多个可信主机

    这是我尝试在我的中设置的 etc pip conf global trusted host pypi org files pythonhosted org 但是 它无法正常工作 参考 https pip pypa io en stable
  • Python Flask 是否定义了路由顺序?

    在我看来 我的设置类似于以下内容 app route test def test app route
  • JPA 将 BigDecimal 作为整数保存在数据库中

    我在数据库中有这个字段 ITEMCOST NUMERIC 13 DEFAULT 0 NOT NULL 在JAVA中 Entity中的字段定义如下 Column name ITEMCOST private BigDecimal itemCos
  • 如何将库添加到 LIBGDX 项目的依赖项 gradle

    一切都在问题中 我已经尝试了在 SO 和其他网站中找到的所有答案 但没有运气 这就是我迄今为止尝试过的 adding compile fileTree dir lib include jar 到我的 build gradle adding
  • Scrapy 蜘蛛无法工作

    由于到目前为止没有任何效果 我开始了一个新项目 python scrapy ctl py startproject Nu 我完全按照教程操作 创建了文件夹和一个新的蜘蛛 from scrapy contrib spiders import
  • CXF:通过 SOAP 发送对象时如何排除某些属性?

    我使用 Apache CXF 2 4 2 当我将数据库中的某个对象返回给用户时 我想排除一些属性 例如密码 我怎样才能做到这一点无需创建临时的班级 有这方面的注释吗 根据 tomasz nurkiewicz 评论我应该使用 XmlTrans
  • Java 中的微分方程

    我正在尝试用java创建一个简单的SIR流行病模型模拟程序 基本上 SIR 由三个微分方程组定义 S t l t S t I t l t S t g t I t R t g t I t S 易感人群 I 感染人群 R 康复人群 l t c
  • 如何对字符串列表进行排序?

    在 Python 中创建按字母顺序排序的列表的最佳方法是什么 基本回答 mylist b C A mylist sort 这会修改您的原始列表 即就地排序 要获取列表的排序副本而不更改原始列表 请使用sorted http docs pyt
  • 如何使用 Django (Python) 登录表单?

    我在 Django 中构建了一个登录表单 现在我遇到了路由问题 当我选择登录按钮时 表单不会发送正确的遮阳篷 我认为前端的表单无法从 查看 py 文件 所以它不会发送任何 awnser 并且登录过程无法工作 该表单是一个简单的静态 html
  • 如何在SqlAlchemy中执行“左外连接”

    我需要执行这个查询 select field11 field12 from Table 1 t1 left outer join Table 2 t2 ON t2 tbl1 id t1 tbl1 id where t2 tbl2 id is

随机推荐

  • SqlDataAdapter

    ado net提供了丰富的数据库操作 在这些操作中SqlConnection和SqlCommand类是必须使用的 但接下来可以分为两类操作 一类是用SqlDataReader直接一行一行的读取数据库 第二类是SqlDataAdapter联合
  • IntelliJ IDEA2021.1 安装golang 插件

    golang插件安装前置条件 1 安装IntelliJ IDEA2021 1 安装步骤参考 IntelliJ IDEA安装操作步骤 2 已安装golang 安装环境参考 Go语言开发包 第一步 用户需要登陆 IDEA 的官网下载新版的gol
  • NotADirectoryError: [WinError 267] 目录名称无效。: ‘123456.txt‘

    NotADirectoryError WinError 267 目录名称无效 123456 txt 状况 python中出现如下情况 NotADirectoryError WinError 267 目录名称无效 123456 txt 问题
  • 全球及中国金属包装市场发展状况与竞争趋势研究报告2022版

    全球及中国金属包装市场发展状况与竞争趋势研究报告2022版 HS HS HS HS HS HS HS HS HS HS HS HS 修订日期 2021年11月 搜索鸿晟信合研究院查看官网更多内容 第一章 金属包装相关概述 1 1 包装和金属
  • 大数据项目之Flink实时数仓(数据可视化接口实现)

    设计思路 之前数据分层处理 最后把轻度聚合的结果保存到 ClickHouse 中 主要的目的就是提供即时的数据查询 统计 分析服务 这些统计服务一般会用两种形式展现 一种是为专业的数据分析人员的 BI 工具 一种是面向非专业人员的更加直观的
  • samba 4.6.5 从编译到配置

    为了防范永恒之蓝等samba病毒的传播 需要及时更新samba服务 本文介绍了在Ubuntu16 04版本上编译 配置samba 4 6 5 的方法 卸载当前系统中的samba sudo apt get remove samba commo
  • 数值分析 第七章 常微分方程的数值解法

    1 数值解法相关公式 1 1 为什么要研究数值解法 所谓数值解法 就是设法将常微分方程离散化 建立差分方程 给出解在一些离散点上的近似值 1 2 问题 7 1 一阶常微分方程初值问题的一般形式 y f x y a x by a begin
  • 借 __attribute__ 引入 The GNU C Reference Manual

    attribute 是 GNU C 规范的一个编译期关键字 话题文档主页 The GNU C Reference Manual GNU Project Free Software Foundation 在一般的Linux中 在文件 usr
  • selenium 360启动

    from selenium webdriver chrome options import Options from selenium import webdriver import time chrome options webdrive
  • MIPI接口中DPHY、CPHY简介及概要设计

    一 分类简介 MIPI是移动领域最主流的视频传输接口规范 目前应用最广泛的是MIPI DPHY和MIPI CPHY两组协议簇 另外还有MIPI MPHY 属于高速Serdes范畴 应用不那么广泛 1 MIPI DPHY 是MIPI的一种物理
  • RHEL 6 修改网卡名称

    RHEL Redhatenterprise linux 6 修改网卡名称 某些服务器安装redhat 6 4时 会自动把网卡名字设置为em1 em2等等 而不是以前的是eth0 等 但是flexlm只认识eth0的mac地址 不过我在虚拟机
  • Vue报错:Error in v-on handler: “TypeError: Cannot read properties of undefined (reading ‘skuId‘)“

    背景 当点击按钮时候 正常情况控制台的Network应该要发送一个变化量 现在控制台的Network不仅不显示 而且还报错 报错信息如下 vue runtime esm js c320 619 Vue warn Error in v on
  • 亚马逊云科技的区域和可用区概念解释

    对于刚开始接触AWS的用户而言 区域 Region 和可用区 Availability Zone AZ 这两个概念有点不好理解 初次接触时往往不知道它们跟我们日常说的数据中心是什么关系 然而区域和可用区是AWS中非常基础和重要的概念 因此我
  • 解决 jenkins 插件下载失败问题 - 配置 jenkins 插件中心为国内镜像地址

    参考资料 解决 jenkins 插件下载失败问题 配置 jenkins 插件中心为国内镜像地址 从 jenkins 官网上下载的 jenkins 在安装的过程中 会有安装插件一环 第一个为默认安装 第二个为手动 选择默认安装之后 会遇到 安
  • 线程休眠、礼让、等待

    线程的状态 线程中的方法 boolean isAlive 测试线程是否处于活动状态 setPriority int newPriority 更改线程优先级 static void sleep long millis 让指定线程休眠指定的毫秒
  • QML + KDDockWidget 实现 tabwidget效果( 窗口可独立浮动和缩放)

    前言 前面文章介绍过在QML中使用ListView实现TabBar标签拖拽交换位置效果 文章在这里 先在此基础上升级一下 结合KDDockWidget做一个可浮动的窗口效果 关于KDDockWidget的介绍 以前的文章有写过 可参考 qm
  • Tango和ROS在LabVIEW的联合测试

    环境 LabVIEW2018 32位 Tango ROS for LabVIEW Software v2 1 0 2 步骤 目标 变量传递顺序 Tango Client Tango Server ROS Publisher ROS Subs
  • 【MySQL】内置函数

    需要云服务器等云产品来学习Linux的同学可以移步 gt 腾讯云 lt gt 阿里云 lt gt 华为云 lt 官网 轻量型云服务器低至112元 年 新用户首次下单享超低折扣 目录 一 日期函数 1 函数用法 1 1current date
  • openGL之API学习(五十四)glDepthFunc

    指定深度测试比较的方法 如果满足深度测试条件则赢得深度测试并会被渲染出来 void glDepthFunc GLenum func func Specifies the depth comparison function Symbolic
  • Pytorch数据读取与预处理实现与探索

    在炼丹时 数据的读取与预处理是关键一步 不同的模型所需要的数据以及预处理方式各不相同 如果每个轮子都我们自己写的话 是很浪费时间和精力的 Pytorch帮我们实现了方便的数据读取与预处理方法 下面记录两个DEMO 便于加快以后的代码效率 根