PyTorch:Torchvision的简单介绍与使用

2023-11-20

安装

pip install torchvision

torchvision独立于pytorch,专门用来处理图像,通常用于计算机视觉领域。

重点介绍torchvision最常用的三个包:

models:提供了很多常用的训练好的网络模型,我们可以直接加载并使用,如Alexnet、ResNet等。

datasets:提供了(1)一些常用的图片数据集,如MNIST、COCO等(2)加载自己的数据集的常用方法,目前只有DatasetFolder、ImageFolder、VisionDataset三个方法。

transforms:提供了一些常用的图像转换处理操作,主要针对Tensor或PIL Image进行操作。

下面我们简单介绍一下这些包的用法,更多详细的内容可以查阅Torchvision的官方文档源代码

一、models包

Torchvision提供的models包包括了很多常用模型的定义,继承自torch.nn.Module类,使用起来十分简单方便。以分类模型AlexNet为例,用法如下。

AlexNet类的初始化方法为

torchvision.models.alexnet(pretrained=False, progress=True)

参数的含义:

  1. pretrained:bool类型。设置为True时,返回一个使用ImageNet数据集预训练过的模型,会下载权重并存储在缓存路径下(缓存路径在TORCH_MODEL_ZOO的环境变量中设置);设置为False时,只导入模型不导入参数,权重是随机初始化的。
  2. progress:bool类型。设置为True时,会显示下载进度条。

除了个别比较复杂的模型,大多数模型函数的用法和参数与该函数类似,不再多举例。

二、transforms包

torchvision.transforms包是torchvision中用来进行图像变换(transform)的包,如裁剪、缩放、翻转等。

transforms包操作的对象通常是PIL Image或torch.tensor类型的图像数据,包中的大多数变换方法对两种输入都适用,但有一些变换只可以输入PIL Image或只可以输入torch.Tensor。下文中我们会对一些常用的方法进行说明,但更详细的内容还是要参见官方文档

1. 多个变换的组合(Compose类)

Compose类的作用是将多个变换组合在一起,初始化方法为

torchvision.transforms.Compose(transforms)

参数的含义:

  1. transform:多个变换组成的列表(list)。

举例

transforms.Compose([
	transforms.CenterCrop(10),
	transforms.Pad(10, 0),
	transforms.ToTensor(),
	transform.Normalize(mean=[0.5, 0.5, 0.5], std=[0.1, 0.1, 0.1])
	transforms.ConvertImageDtype(torch.float),
])

上面的例子中,对一张图像以此进行了以下步骤的变换:

  1. 以输入图像的中心为裁切中心,对图像进行裁切
  2. 对上一步得到的图像的边缘进行Pad填充
  3. 将上一步得到的图像转换成torch.Tensor格式
  4. 将上一步得到的tensor图像进行标准化转换
  5. 将上一步得到的tensor图像转换成torch.float类型

2. 常用的变换(输入可以是PIL Image或torch.Tensor)

下面是一些常用的图像变换,这些变换既可以针对PIL Image格式的输入进行,也可以针对torch.Tensor进行。

transforms.CenterCrop(size)

以图像中心为基点进行裁切。如果图像本身就小于裁切的尺寸,则先将图像Pad填充上0,再进行裁切。
参数:

  1. size:int类型或序列(Sequence)类型,表示裁切后的大小。如果参数size是一个int型的数字,则才切成变成为size大小的正方形;如果参数是序列类型(h, w),则将图像裁切成h×w大小。
transforms.Pad(padding, fill=0, padding_mode='constant')

对图像边缘进行pad填充,参数的含义如下:

  1. padding:int类型或 Sequence类型,可选。如果是一个int类型的数字,则表示每个边缘都进行pad填充;如果是一个长度为2的Sequence,则表示分别对左右边界、上下边界进行不同宽度的pad填充;如果是一个长度为4的Sequence,则表示分别对左、上、右、下进行不同的pad填充。
  2. fill:数字类型或str类型或元组(tuple)类型,只有当padding_mode参数为“constant”时有效。如果是一个数值,则使用该数值所代表的颜色进行填充;如果是一个长度为3的元组,则使用元组中的元素分别对R、G、B三个通道进行填充。torch.Tensor只支持数值类型,PIL Image只支持int、str或元组类型。
  3. padding_mode:str类型,表示填充的类型,可以是下面的类型之一
    - constant:使用一个常量进行填充,这个常量由fill参数进行指定。
    - edge:使用图片最外侧边缘的像素颜色进行填充。
    - reflect:使用图片的镜像图像对边缘进行填充,最边缘的像素不重复。例如,对[1, 2 , 3, 4]的左右两侧进行reflect填充的结果为[3, 2, 1, 2, 3, 4, 3, 2]。
    - symmetric:使用图片的镜像图像对边缘进行填充,最边缘的像素重复一次。例如,对[1, 2 , 3, 4]的左右两侧进行reflect填充的结果为[2, 1, 1, 2, 3, 4, 4, 3]。
transforms.RandomCrop(size, padding=None, pad_if_needed=False,
						fill=0, padding_mode='constant')

以图像上的任意一个位置作为基点进行裁切。参数的含义如下:

  1. size:用法与CenterCrop的size参数相同。
  2. pad_if_needed:boolean类型。当指定当图像小于裁切后的大小时是否进行pad填充。
  3. padding、fill、padding_mode:用法与Pad的 padding、fill、padding_mode参数相同。
transforms.Grayscale(num_output_channels=1)

将图像转换成灰度图。输入的参数含义如下:

  1. num_output_channels:int类型,只能取1或3。如果为1,表示返回的图像只有一个通道;如果为3,表示返回的图像有三个通道,但r == g == b。

3. 常用的变换(输入只能是torch.Tensor)

transforms.Normalize(mean, std, inplace=False)

对tensor图片进行标准化,均值为mean,标准差为std。不支持PIL Image。

假设图片有n个通道,由于这个标准化操作对图片的每一个通道都要进行,因此mean和std都是长度为n的序列,对应每个通道的标准化参数。

参数含义:

  1. mean:序列类型,表示每个通道的均值。
  2. std:序列类型,表示每个通道的标准差。
  3. implace:bool类型,可选。是否让这个操作in-place。
transforms.ConvertImageDtype(dtype)

将一张tensor图片转换成指定的数据类型(data type, dtype),如torch.float32、torch.int等。该方法不适用于PIL Image。参数的含义如下

  1. dtype:torch.dtype类型,表示希望转换成的数据类型。

4. 数据类型转换

下面是PIL Image和torch.Tensor两种数据类型相互转换的方法。

transforms.ToPILImage(mode=None)

将tensor或numpy.ndarray类型的数据转换成PIL Image,这个转化不支持torchscript。如果tensor或ndarray数据的大小为C×H×W大小,则转换后的PIL Image大小为H×W×C。

参数含义:

  1. mode:PIL Image mode类型,可选,表示输入数据的色彩空间或像素深度。如果mode是None,则实际的mode与输入图像的通道数有关:若输入有四个通道,则mode为RGBA;若输入有三个通道,则mode为RGB;若输入有两个通道,则mode为LA;若输入有一个通道,则根据数据类型mode被设置为int、short或float等数值类型。

解释一下,像素深度(Pixel Depth)是存储每个像素所用的位数,色彩空间(Color Space)是指一些常用的颜色模型构成色彩的集合,如RGB、CMYK、LAB等。

torchvision.transforms.ToTensor()

将PIL Image或numpy.ndarray类型的数据转换成tensor,这个转化不支持torchscript。

4. 举例:对单个的样本进行变换

这些transform类的一个应用场景是,假设我们有一个图片样本,现在想直接对它进行变换,应该如何调用这些类来完成这个操作。

假设我们读取了一张PIL Image格式的图片,现在希望把它转换成Tensor类型,则方法为

from PIL import Image
from torchvision import transforms

image = Image.open(r"/path/to/image.jpg")
img_tensor = transform.ToTensor()(image)

即可得到转换后的Tensor。

三、datasets包

datasets包中所有的类几乎都直接或间接继承自torch.utils.data.Dataset类,因此,借由datasets包得到的数据集都可以再传递给torch.utils.data.DataLoader,由它进行多线程并行加载样本数据。例如并行加载一个ImageNet数据集的代码如下

imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=args.nThreads)

1. 加载常用数据集

以加载MNIST数据集为例,MNIST类的初始化方法如下

torchvision.datasets.MNIST(root, train=True, transform=None, 
						target_transform=None, download=False)

它的参数含义如下:

  1. root:string类型,表示存放MNIST数据集的根目录。如果加载的是训练集就是这个参数就表示训练集的根目录,如果加载的是测试集这个参数就表示测试集的根目录。
  2. train:bool类型,可选。为True是表示加载的是训练集,为False时表示加载测试集。
  3. download:bool类型,可选。为True时表示从互联网上下载数据集并保存在root参数指定的目录下,如果数据集已经下载过则不会重复下载。
  4. transform:可调用函数类型,可选。使用torchvision.transforms中定义的类对输入的图片进行变换,输入的图片必然是PIL Image格式的。
  5. target_transform:可调用函数类型,可选。使用torchvision.transforms中定义的类对目标(target)进行变换。

这里我们重点介绍一下transform和target_transform两个参数。两个参数都要使用前文介绍的transforms包,可以对输入或目标进行一个变换或一组变换,返回变换后的结果。例如

from torchvision import datasets
from torchvision import transforms

tran = transforms.Compose([
  transforms.CenterCrop(50),
  transforms.ToTensor(),
])

mnist = datasets.MNIST(root="path/to/MNIST", train=True, download=True, transform=tran)

在上面的代码中,我们定义了一组变换tran,它会将图片先裁切成50×50大小,再转换成tensor格式。

然后我们读取"path/to/MNIST"路径下的数据集MNIST,如果数据集不存在就从网络上下载到该路径下,对于数据集中的每个图片,我们都进行tran定义的变换。

target_transform参数的用法与transform参数类似,只不过是针对数据的目标进行变换。例如,在语义分割问题中,可以对分割的mask(Segmentation mask)进行变换;在分类问题中,也可以对分类标签进行变换。

2. 加载用户自定义的数据集(ImageFolder类)

Torchvision提供了一个常用的加载自定义数据集的类ImageFolder,使用它加载图片数据集时,数据集需要以这样的方式进行组织:将不同类别的数据分别放在不同目录下,每个目录的名字就是数据的标签。举个例子如下(注:图片本身可以随意起名,但必须保证具有同一类的图片放在一个目录下)

├──cat
¦	├──cat_001.jpg
¦	├──cat_002.jpg
¦	└──……
├──dog
¦	├──dog_001.jpg
¦	├──dog_002.jpg
¦	├──dog_003.jpg
¦	└──……
└──……

随便找了几张图片作为这个示例。

ImageFolder类的初始化方法如下

torchvision.datasets.ImageFolder(root, transform=None, target_transform=None
								loader=default_loader, is_valid_file=)

参数的含义如下

  1. root:string类型。存放数据集的根目录。
  2. transform、target_transform:用法和含义同MNIST类的transform、target_transform参数。
  3. loader:可调用函数类型,可选。加载图片的方式,默认读取RGB格式的PIL Image类型的图像。
  4. is_valid_file: 可调用函数类型,可选。用来检查一个Image文件是否是空的(检查该图片是否损坏)

label就是一张图像的标签,也就是将子文件夹的名字以字典的形式存储起来,即{类名 : 序号},序号从0开始顺序向后计数。上文中的例子中的label就是

{cat : 0}
{dog : 1}
……

常用的成员变量为

  1. classes:list类型,类名
  2. class_to_idx:list类型,label。
  3. imgs:list类型,由图片名称和它的类型组成。
from torchvision.datasets import ImageFolder

dataset = ImageFolder("data/")

print(dataset.classes)
print(dataset.class_to_idx)
print(dataset.imgs)

输出为

dataset[i]可以得到对应图像的数据。dataset[i][0]是PIL图像类型,dataset[i][1]是该图像的类别。

print(dataset[0])

输出为

四、utils包

除了上述三个包外,utils包也是经常用到的包,里面包括了一些在计算机视觉领域经常用到的操作。

1. make_grid函数

函数原型

torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False,
							value_range=None, scale_each=False, pad_value=0) → Tensor

该函数的作用是将若干张图像拼成一幅图像。如下图所示

各参数的含义如下:

  1. tensor:torch.Tensor或列表(list)类型。如果是Tensor类型,则形状必须是四维B×C×H×W,用来表示一组mini-batch的Tensor;如果是list类型,则必须是一组大小相同的图片。
  2. nrow:int类型,可选。表示组成的大图中每一行包括的小图数量。
  3. padding:int类型,可选。每张小图四周的padding的大小。
  4. normalize:bool类型,可选。设置为True时,图像将会被标准化,也就是将图像限制在(0, 1)范围内,最大值和最小值由这个batch中所有图像的最大值最小值确定。
  5. value_range:元组(tuple)类型,可选。用元组(min, max)作为标准化图像时用到的最大值和最小值,默认情况下这两个值是由图像计算得到的。
  6. scale_each:bool类型,可选。设置为True时,计算标准化使用的最大值最小值时,每个图片使用自己的最值,而不是该batch下所有图片使用同一个最值计算。
  7. pad_value:float类型,可选。表示pad填充的像素的颜色。

2. save_image函数

torchvision.utils.save_image(tensor, fp, format=None)None

直接将给定的Tensor保存成图片。
参数的含义:

  1. tensor:Tensor或list类型。表示需要被保存的图片,如果给定的是一个mini-batch的tensor,则会自动调用make_grid函数将这些图片组合成网格形式再保存。
  2. fp:string类型或文件对象(file object)类型。将图片保存到该参数指定的文件路径或文件对象。
  3. format:可选。fp如果是文件名,则忽略;fp如果是文件对象,则该参数经常被使用。

3. draw_bounding_boxes函数

torchvision.utils.draw_bounding_boxes(image, boxes, labels=None, colors, fill=False,
										width=1, font=None, font_size=10) → Tensor

在给定的图像中画出bounding box。输入的图像必须是uint8类型,且范围在0到255之间。

参数的含义如下:

  1. image:Tensor类型。表示被操作的图片,Tensor的大小为C×H×W,dtype类型为uint8。
  2. boxes:Tensor类型。如果一张图像上有N个bounding box,则boxes参数的大小为(N, 4),box的四个点以(xmin, ymin, xmax, ymax)的形式进行组织。注意,这个坐标是图像上的绝对坐标。
  3. labels:list类型,list中每个成员都是str类型。表示一系列bouding box的标签。
  4. colors:可以使用一个str或使用元组(r, g, b)将所有bounding box指定成同一个颜色,也可以使用一个list对每个bouding box指定不同的颜色,list的每个成员也用str或(r, g, b)表示。
  5. fill:bool类型。设定为True时将bounding box填充为指定的颜色。
  6. width:int类型。表示bounding box的宽度。
  7. font:str类型。表示用到的字体。
  8. font_size:int类型。字体的大小,以点为单位。

返回一张标注了bounding box的图像,类型dtype为uint8。

4. draw_segmentation_masks函数

torchvision.utils.draw_segmentation_masks(image, masks, alpha=0.8, colors=None)

在给定的RGB图像上画出语义分割的mask。输入的图像必须是uint8类型,且范围在0到255之间。

参数的含义:

  1. image:Tensor类型,大小是(3, H, W),类型是uint8。
  2. masks:Tensor类型,大小是(num_masks, H, W)或(H, W),dtype是bool类型。
  3. alpha:float类型,大小在0到1之间,表示mask的透明度,0表示完全透明,1表示不透明。
  4. colors:list类型或None。list类型来指定每个mask的颜色,颜色可以用字符串"red"或"#FF00FF"来表示,也可以用RGB元组(0, 255, 255)表示。如果masks参数的形状是(H, W),则可以不使用list而只传入一个颜色。如果colors是None,则对每个mask随机赋予一个颜色。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch:Torchvision的简单介绍与使用 的相关文章

随机推荐

  • Webpack插件核心原理

    引言 围绕 Webpack 打包流程中最核心的机制就是所谓的 Plugin 机制 所谓插件即是 webpack 生态中最关键的部分 它为社区用户提供了一种强有力的方式来直接触及 webpack 的编译过程 compilation proce
  • knime简介_KNIME简介

    knime简介 Data Science is abounding It considers different realms of the data world including its preparation cleaning mod
  • ubuntu本地安装jdk17

    下载 wget https download oracle com java 17 archive jdk 17 0 7 linux x64 bin tar gz 解压 tar zxvf jdk 17 0 7 linux x64 bin t
  • 【c++】Lambda表达式

    Lambda表达式 Lambda表达式是C 中的匿名函数 允许你在需要时定义和使用小型函数 语法 Lambda表达式的基本语法如下 scssCopy code 捕获列表 参数列表 gt 返回类型 Lambda函数体 捕获列表定义了Lambd
  • PAJ7620U2手势识别——配置0x00寄存器(3)

    文章目录 前言 一 为啥要配置0x00寄存器 二 配置步骤 1 单个读操作步骤图 2 模块状态转移图绘制 3 模块波形图绘制 4 上板验证 5 参考代码 总结 前言 在前面的教程中 小编带领各位读者学习了如何通过I2C协议去唤醒PAJ762
  • vue列表跳转详情,记录列表滚动不变

    记录主元素 computed elTable function return document getElementsByClassName layout content 0 当引入keep alive的时候 页面第一次进入 钩子的触发顺序
  • 回溯法展开状态空间树

    解空间 假设问题的解能用n元组 X1 Xn 表示 其中Xi取自某个有穷集Si 这些n元组构成的集合称为问题的解空间 假设集合Si的大小 Si mi 则解空间的大小m m1 m2 mn 注意这里解空间的大小取决于元组中每个元素的可能取值的数量
  • STM32内部参照电压VREFIN的使用

    一 STM32的内部参照电压VREFINT和ADCx IN17相连接 它的作用是相当于一个标准电压测量点 和MSP430不一样 内部参照电压VREFINT只能出现在主ADC1中使用 内部参照电压VREFINT与参考电压不是一回事 ADC的参
  • 详解git pull和git fetch的区别:

    前言 在我们使用git的时候用的更新代码是git fetch git pull这两条指令 但是有没有小伙伴去思考过这两者的区别呢 有经验的人总是说最好用git fetch git merge 不建议用git pull 也有人说git pul
  • @Slf4j 实现日志输入到外部文件

    1 添加一个配置文件 src main resources logback spring xml
  • 使用html+css+javaScript 完成计算器

    一 先用html与css搭建骨架 思路 将计算器的数字按钮放进一个表格里 再通过css修饰 然后对指定的数字按钮或功能按钮添加事件 将需要计算的式子放进一个字符串里 最后通过全局方法eval 计算出来 html的骨架搭建 这里的用一个 di
  • 西门子300系列基本逻辑编程:手自动选择程序及自定义脉冲模块的使用

    西门子内置脉冲发生器 M0 0 0 1S M0 1 0 2S M0 2 0 4S M0 3 0 5S M0 4 0 8S M0 5 1 0S M0 6 1 6S M0 7 2 0S 案例 手自动选择程序 控制要求 I0 0是手自动选择开关
  • Vmware虚拟机和主机之间复制、粘贴内容、拖拽文件的详细方法

    Vmware正确安装完linux虚拟机之后 这里以Ubuntu为例 其他linux或windows系统也是类似的 如果你使用的默认配置 正常情况下就可以复制 粘贴和拖拽内容的 双方向都是支持的 如果不能复制和拖拽一般是vmware tool
  • mongodb入门(2)

    目录 一 mongodb入门 1基础概念 2连接mongodb 3 数据库 4 集合 5 文档 1 插入文档 2 更新文档 3删除文档 4查询文档 6用户 1 创建用户 2查询用户 3删除用户 4修改用户 5修改密码 一 mongodb入门
  • 10 个牛逼的单行代码编程技巧,你会用吗?

    标题本文列举了十个使用一行代码即可独立完成 不依赖其他代码 的业务逻辑 主要依赖的是Java8中的Lambda和Stream等新特性以及try with resources JAXB等 1 对列表 数组中的每个元素都乘以2 Range是半开
  • Spring 中如何使用SpEL表达式语言呢?

    转自 Spring 中如何使用SpEL表达式语言呢 SpeL简介说明 SpeL Spring Expression Language是一种功能强大的表达式语言 支持运行时查询和操作对象图 使用SpeL可采用最少的代码 完成大量的工作 注意事
  • vim菜鸟学习-中级篇2(经典配置)

    参考资料 http www cnblogs com striveford archive 2011 02 09 1950369 html http blog csdn net xjanker2 article details 5832784
  • linux环境下安装jmeter

    linux压力机安装jmeter 1 在Linux服务器先安装jdk 配置环境变量 2 下载 apache jmeter 5 4 1tgz https jmeter apache org download jmeter cgi 上传到服务器
  • UE4 解决景深效果闪烁问题

    原因 1 模型的垂直竖线 造成抗锯齿算法对竖线的渲染计算 处于一种不稳定的状态 因此闪烁 解决办法 使用LOD 用贴图去替代线条模型 2 材质的法线贴图 当法线贴图含有垂直竖线的纹理效果 也会造成闪烁 比如这种幕墙材质 解决办法 关闭或动态
  • PyTorch:Torchvision的简单介绍与使用

    安装 pip install torchvision torchvision独立于pytorch 专门用来处理图像 通常用于计算机视觉领域 重点介绍torchvision最常用的三个包 models 提供了很多常用的训练好的网络模型 我们可