本篇文章是翻译:https://deeplizard.com网站中的关于Pytorch学习的文章,供学习使用。
原文地址为:https://deeplizard.com/learn/video/8n-TGaBZnk4
使用Pytorch进行提取(E)、转换(T)和加载(L)
欢迎回到基于PyTorch的神经网络编程系列课程。在这一部分,我们将编写我们的第一个代码。
我们将使用torchvision和PyTorch的计算机视觉包进行一个简单的提取、转换和加载过程的展示。废话不多说,现在开始。
概述
在这个项目中,我们将遵循四个步骤:
- 准备数据
- 建立模型
- 训练模型
- 分析模型的结果
ETL过程(提取、转换和加载)
在这篇文章中,我们从准备数据开始,为了准备数据,我们将会遵循ETL过程。
- 从数据源中提取数据。
- 将数据转换成理想的格式。
-
将数据加载到合适的结构(模型)中。
ETL过程可以被认为是一个分形过程,因为它可以应用在不同的尺度上。这个过程可以应用在小的尺度上,就像一个单一的程序。或者应用在一个大的尺度上,如在一个企业级别的巨大系统的每个小部分。
如果你想知道更多关于通用数据科学,请查看数据科学文章 ,文章中有详细解释。
一旦我们完成了数据的ETL过程,我们就可以进行模型的创建和训练了。PyTorch中有一些内置的包和类,可以很方便的进行ETL过程。
PyTorch导包
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
下面的表格将介绍上面的各个包。
包 |
介绍 |
torch |
最高级别的PyTorch包和tensor库。 |
torch.nn |
包含模型和建立神经网络扩展类的子包。 |
torch.optim |
包含SGD和Adam等标准优化操作到子包。 |
torch.nn.functional |
包含建立神经网络的典型操作的函数接口。(损失函数和卷积等。) |
torchvision |
为计算机视觉提供通用数据集通道、模型结构和图像转换的包。 |
torchvision.transforms |
为图像处理提供普通转换的接口。 |
导入其他的包
下面导入使用Python所需要的标准数据包。
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
#from plotcm import plot_confusion_matrix
import pdb
torch.set_printoptions(linewidth=120)
pdb
是一个编译器,import
(导入)一个本地文件,我们将会在以后的文章进行混淆矩阵(confusion matrix)的介绍。最后一行设置PyTorch打印语句的打印选项。
现在我们可以进行数据的准备了。
使用PyTorch进行数据的准备
我们准备数据是遵循下面步骤:
- 提取:从数据源中获取
Fashion-MNIST
数据集。
- 转换:将我们的数据转换成
tensor
类型。
- 加载:将我们的数据放在一个对象中,便于我们访问。
PyTorch提供给我们两个类:
类 |
介绍 |
torch.utils.data.Dataset |
代表一个数据集的抽象类。 |
torch.utils.data.DataLoader |
包装一个数据集并提供对底层数据的访问。 |
抽象类
是Python中必须执行的方法,所以我们可以通过创建一个子类来创建一个自定义数据集,该子类用来扩展 ‘数据集类(抽象类)’ 的功能。
使用PyTorch创建一个自定义类,我们通过创建子类来扩展数据集类的请求方法。在做这个的时候,我们的子类可以被传入PyTorch DataLoader
对象。
我们将会使用fashion-MNIST数据集(内置在torchvison
包中),因为数据集内置在数据包中,所以我们不必对我们的项目准备fashion-MNIST数据包。只需要知道Fashion-MNIST内置数据集正在幕后工作就可以了。
所有的数据集类的子类都必须进行重写,__len__
提供一个数据集的大小,__getitem__
支持从0
至len(self)
的整数索引。
PyTorch 的Torchvision包
torchvision数据包为我们提供获得资源的通道。我们可以获得一下资源:
- Datasets(像
MNIST
和Fashion-MNIST
数据包)
- Models(模型,
VGG16
)
- Transforms
- Utils
计算机视觉
所有的这些资源都和计算机视觉相关。
我们已经在之前的文章中学习了Fashion-MNIST数据集的介绍,通过arXiv文献我们可以知道Fashion-MNIST的作者主要的意图是制作一个可以代替MNIST的数据集。
这个意图使得在使用PyTorch在添加Fashion-MNIST数据集时只需要更改URL地址就可以了。
在这个案例PyTorch案例中,这个FashionMNIST
数据集扩展了MNIST
数据集并重写了urls。
这里是这个类定义的torchvision
资源代码:
class FashionMNIST(MNIST):
"""`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.
Args:
root (string): Root directory of dataset where ``processed/training.pt``
and ``processed/test.pt`` exist.
train (bool, optional): If True, creates dataset from ``training.pt``,
otherwise from ``test.pt``.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
urls = [
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
]
现在让我们看看如何利用torchvision
。
PyTorch 数据集类
为了使用torchvision
获取fashionMNIST数据集的实例,我们只需要做如下操作:
train_set = torchvision.datasets.FashionMNIST(
root='./data'
,train=True
,download=True
,transform=transforms.Compose([
transforms.ToTensor()
])
)
以前root的参数是'./data/FashionMNIST'
,然而,现在因为torchvision
的更新而改变了。
以上参数的解释:
参数 |
描述 |
root |
数据在磁盘上的位置 |
train |
数据集是否是用来进行训练的,True表示数据集用来训练 |
dowenload |
这个数据是否应该被下载 |
transform |
应该在数据集元素上执行的转换的组合。 |
如果我们想让我们的图像转换成tensors格式,我们可以使用内置函数transforms.ToTensor()
。因为数据集被用来进行训练,所以我们将其命名为train_set
.
当我们第一次运行这个代码的时候,这个Fashion-MNIST数据集将会被下载。随后在数据下载之前进行系统调用检测。因此,我们不用担心数据集会被下载两次。
DataLoader 类
为我们的训练集创建一个DataLoader,我们只需要做:
train_loader = torch.utils.data.DataLoader(train_set
,batch_size=1000
,shuffle=True
)
我们只需要将train_set作为一个参数传入。现在我们可以使用这个加载器而不是手动进行我们的测试:
-
batch_size
:(在我们的案例中我们使用1000)。
-
shuffle
:(参数设置为True)
-
num_workers
:(默认值是0,意思是我们将使用主程序)
总结
在ETL的观点中,我们完成了提取和使用torchvision进行转换。当我们创建数据集时:
- Extract:从网站中提取未加工的数据。
- Transform:将未加工的(原始)图像转换成tensor格式。
- Load:这个
train_set
包装了数据加载器,并给我们提供了通往底层数据的通道。
现在,我们对PyTorch提供的torchvision有了很好的理解,并且也知道了如何使用数据集和数据加载器进行ETL操作。
在在下一篇文章中,我们将会看到我们可以进行#¥%¥@%¥等一些操作。期待与你的下一次见面。