PyTorch学习(2):数据加载机制
- Pytorch官方文档:https://pytorch-cn.readthedocs.io/zh/latest/
- Pytorch学习文档:https://github.com/tensor-yu/PyTorch_Tutorial
- 参考:https://blog.csdn.net/u011995719/article/details/85102770
文章目录
- PyTorch学习(2):数据加载机制
- 前言
- 1.Dataset类
- 2.构建自定义Dataset子类
- 3.DataLoader
- 4.PyTorch数据迭代过程
- 5.Torchvision的数据读取类
- 总结
前言
在处理任何机器学习问题之前都需要数据读取,并进行预处理。想让PyTorch读取自己构建的数据集,那么要先了解PyTorch读取图像的机制和流程,然后按照流程进行代码编写。
1.Dataset类
PyTorch读取图像,主要是通过torch.utils.data.Dataset抽象类,可以自己定义数据类继承和重写这个抽象类,非常简单,只需要定义__len__和__getitem__这两个函数。
Dataset类源码如下:
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
重点在复写getitem函数,getitem接收一个index,然后返回图像数据和标签,这个index通常是数据List的索引index;len函数返回数据集的数量。
2.构建自定义Dataset子类
构建分类模型的数据读取Dataset子类–MyDataset类:
from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, txt_path, transform = None, target_transform = None):
with open(txt_path, 'r') as f:
lines = f.readlines()
imgs = []
for line in lines:
value = line.strip().split()
imgs.append((value[0], int(value[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
data, label = self.imgs[index]
img = Image.open(data).convert('RGB')
if self.transform is not None:
img = self.transform(img)
return img, label
def __len__(self):
return len(self.imgs)
通过上面的方式,可以定义模型需要的数据类,可以通过迭代的方式来获取每一个数据,但是很难实现batch、shuffle或者多线程去读取数据,所以PyTorch提供了torch.utils.data.DataLoader来定义一个迭代器。
3.DataLoader
DataLoader提供了对Dataset的读取操作,它是个iterable,可以进行相关迭代操作。常用参数有: batch_size(每个batch的大小), shuffle(是否进行shuffle操作), num_workers(加载数据的时候使用几个子进程)等等。
类定义为:
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
可以看到,主要的参数有:
1)dataset: 即上述自定义的dataset;
2)Batch_size: 每次迭代加载样本的数量;
3)shuffle:对标签进行乱序;
4)num_workers:使用多进程加载的进程数,0代表不使用多进程。
5)collate_fn:是表示如何获取样本的,可以定义自己的函数来准确地实现想要的功能,默认的函数基本已满足需求,不必自定义实现采用策略。
样例: 自定义数据集读取
train_loader = torch.utils.data.DataLoader(MyDataset(), batch_size = 10, shuffle = True, num_workers = 0)
4.PyTorch数据迭代过程
伪代码1:
class CustomDataset(Dataset):
dataset = CustomDataset()
dataloader = Dataloader(dataset, ...)
for data in dataloader:
在for 循环里,总共有三点操作:
1)调用了dataloader 的__iter__() 方法, 产生了一个DataLoaderIter;
2)反复调用DataLoaderIter 的__next__()来得到batch, 具体操作就是, 多次调用dataset的__getitem__()方法 (如果num_worker>0就多线程调用),然后用collate_fn来把它们打包成batch。中间还会涉及到shuffle ,以及sample 的方法等。
3)当数据读完后, next()抛出一个StopIteration异常, for循环结束,dataloader 失效。
伪代码2:
class DataIterator(object):
def __init__(self, dataloader):
self.dataloader = dataloader
self.iterator = enumerate(self.dataloader)
def next(self):
try:
_, data = next(self.iterator)
except Exception:
self.iterator = enumerate(self.dataloader)
_, data = next(self.iterator)
return data[0], data[1]
train_dataprovider = DataIterator(dataloader )
for iters in range(1, train_interval + 1):
data, target = train_dataprovider.next()
注: 伪代码1和2为比较常见的PyTorch数据加载方式,数据格式根据自己要求随意指定,常见的分类任务的标签List为:图像路径+标签。
5.Torchvision的数据读取类
Torchvision 这个包中还提供了一个更高级的关于计算机视觉的数据读取类:ImageFolder,主要功能是处理图像,且要求图像是下面这种存放形式:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
即默认数据集已经按照分配好的类别已经分成了不同的文件夹,一种类型的文件夹下面只存放一种类型的图像。
之后采用下面的方式调用这个类:
import torchvision.datasets as dset
dset.ImageFolder(root="root folder path", [transform, target_transform])
1)root:指定图像存储的路径
2)transform: 一个函数,原始图像作为输入,返回数据增强转换后的图像。
3)target_transform: 一个函数,输入为target,输出对其的转换。
有以下成员变量:
1)self.classes - 用一个list保存 类名
2)self.class_to_idx - 类名对应的 索引
3)self.imgs - 保存(img-path, class) tuple的list
torchvision数据加载代码:
import torchvision.datasets as datasets
train_dataset = datasets.ImageFolder(
args.train_dir,
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
transforms.RandomHorizontalFlip(0.5),
ToBGRTensor(),
])
)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=True,
num_workers=1, pin_memory=use_gpu)
总结
至此,基于PyTorch框架的数据加载机制已基本了解。(1)构建自己数据集的Dataset子类;(2)使用DataLoder来构建数据迭代器;(3)使用torchvision库来加载数据。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)