PyTorch学习(2):数据加载机制

2023-05-16

PyTorch学习(2):数据加载机制

  • Pytorch官方文档:https://pytorch-cn.readthedocs.io/zh/latest/
  • Pytorch学习文档:https://github.com/tensor-yu/PyTorch_Tutorial
  • 参考:https://blog.csdn.net/u011995719/article/details/85102770

文章目录

  • PyTorch学习(2):数据加载机制
  • 前言
    • 1.Dataset类
    • 2.构建自定义Dataset子类
    • 3.DataLoader
    • 4.PyTorch数据迭代过程
    • 5.Torchvision的数据读取类
  • 总结


前言

  在处理任何机器学习问题之前都需要数据读取,并进行预处理。想让PyTorch读取自己构建的数据集,那么要先了解PyTorch读取图像的机制和流程,然后按照流程进行代码编写。


1.Dataset类

PyTorch读取图像,主要是通过torch.utils.data.Dataset抽象类,可以自己定义数据类继承和重写这个抽象类,非常简单,只需要定义__len__和__getitem__这两个函数

Dataset类源码如下:

class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
	raise NotImplementedError
def __len__(self):
	raise NotImplementedError
def __add__(self, other):
	return ConcatDataset([self, other])

重点在复写getitem函数,getitem接收一个index,然后返回图像数据和标签,这个index通常是数据List的索引index;len函数返回数据集的数量。

2.构建自定义Dataset子类

构建分类模型的数据读取Dataset子类–MyDataset类:

# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, txt_path, transform = None, target_transform = None):
        with open(txt_path, 'r') as f:
            lines = f.readlines()
        imgs = []
        for line in lines:
            value = line.strip().split()
            imgs.append((value[0], int(value[1])))
        self.imgs = imgs 
        self.transform = transform
        self.target_transform = target_transform
        
    def __getitem__(self, index):
    	data, label = self.imgs[index]
    	img = Image.open(data).convert('RGB') 
    	if self.transform is not None:
    		img = self.transform(img) 
    	return img, label
    
    def __len__(self):
    	return len(self.imgs)

通过上面的方式,可以定义模型需要的数据类,可以通过迭代的方式来获取每一个数据,但是很难实现batch、shuffle或者多线程去读取数据,所以PyTorch提供了torch.utils.data.DataLoader来定义一个迭代器。

3.DataLoader

DataLoader提供了对Dataset的读取操作,它是个iterable,可以进行相关迭代操作。常用参数有: batch_size(每个batch的大小), shuffle(是否进行shuffle操作), num_workers(加载数据的时候使用几个子进程)等等。
类定义为:

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

可以看到,主要的参数有:
  1)dataset: 即上述自定义的dataset;
  2)Batch_size: 每次迭代加载样本的数量;
  3)shuffle:对标签进行乱序;
  4)num_workers:使用多进程加载的进程数,0代表不使用多进程。
  5)collate_fn:是表示如何获取样本的,可以定义自己的函数来准确地实现想要的功能,默认的函数基本已满足需求,不必自定义实现采用策略。

样例: 自定义数据集读取

train_loader = torch.utils.data.DataLoader(MyDataset(), batch_size = 10, shuffle = True, num_workers = 0)

4.PyTorch数据迭代过程

伪代码1:

class CustomDataset(Dataset):
   # 定义自己的dataset子类

dataset = CustomDataset()
dataloader = Dataloader(dataset, ...) # 构建数据迭代器

for data in dataloader: # 循环迭代
   # training...

在for 循环里,总共有三点操作:
  1)调用了dataloader 的__iter__() 方法, 产生了一个DataLoaderIter;
  2)反复调用DataLoaderIter 的__next__()来得到batch, 具体操作就是, 多次调用dataset的__getitem__()方法 (如果num_worker>0就多线程调用),然后用collate_fn来把它们打包成batch。中间还会涉及到shuffle ,以及sample 的方法等。
  3)当数据读完后, next()抛出一个StopIteration异常, for循环结束,dataloader 失效。

伪代码2:

class DataIterator(object):  # 构建数据迭代器
    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.iterator = enumerate(self.dataloader)

    def next(self):
        try:
            _, data = next(self.iterator)
        except Exception:
            self.iterator = enumerate(self.dataloader)
            _, data = next(self.iterator)
        return data[0], data[1]

train_dataprovider = DataIterator(dataloader )

for iters in range(1, train_interval + 1):
data, target = train_dataprovider.next() #调用next函数来获取当前batch的数据

注: 伪代码1和2为比较常见的PyTorch数据加载方式,数据格式根据自己要求随意指定,常见的分类任务的标签List为:图像路径+标签

5.Torchvision的数据读取类

Torchvision 这个包中还提供了一个更高级的关于计算机视觉的数据读取类:ImageFolder,主要功能是处理图像,且要求图像是下面这种存放形式:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

即默认数据集已经按照分配好的类别已经分成了不同的文件夹,一种类型的文件夹下面只存放一种类型的图像。

之后采用下面的方式调用这个类:

import torchvision.datasets as dset
dset.ImageFolder(root="root folder path", [transform, target_transform])

  1)root:指定图像存储的路径
  2)transform: 一个函数,原始图像作为输入,返回数据增强转换后的图像。
  3)target_transform: 一个函数,输入为target,输出对其的转换。

有以下成员变量:
  1)self.classes - 用一个list保存 类名
  2)self.class_to_idx - 类名对应的 索引
  3)self.imgs - 保存(img-path, class) tuple的list

torchvision数据加载代码:

import torchvision.datasets as datasets
# 数据读取
train_dataset = datasets.ImageFolder( 
    args.train_dir,
    transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
        transforms.RandomHorizontalFlip(0.5),
        ToBGRTensor(),
    ])
)
# 数据迭代
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=True,
    num_workers=1, pin_memory=use_gpu)

总结

  至此,基于PyTorch框架的数据加载机制已基本了解。(1)构建自己数据集的Dataset子类;(2)使用DataLoder来构建数据迭代器;(3)使用torchvision库来加载数据。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch学习(2):数据加载机制 的相关文章

随机推荐