使用pytorch加载数据集和对数据集进行处理

2023-05-16

目录

1.torchvision中加载数据集

2.重写Dataset类加载数据集

3.transforms

4.Dataloader对数据进一步处理


1.torchvision中加载数据集

官方文档给出的数据

 

下面以CIFAR数据集为例子:

torchvision.datasets.CIFAR10(root: str, train: bool = True,
 transform: Optional[Callable] = None, target_transform: 
Optional[Callable] = None, download: bool = False)
  • root:表示数据集的路径
  • train:表示是否为训练集,为True表示为训练数据集,否则为测试集。
  • transform:表示对数据集进行转换,下面已经对该功能进行了说明。
  • target_transform:表示对target进行数据转换。
  • download:是否下载,如果为True的话,表示从网上进行下载该数据集,否则从已有的文件目录下面获取。
import os
import torch
import numpy as np
from PIL import Image
from torchvision import datasets,transforms

#数据集的预处理
transform=transforms.Compose([
    transforms.ToTensor()
])

#下载数据集
root='myDataset/CIFAR10'
#将数据集下载之后保存到root路径下(如果数据集已经存在,则不从网上下载,直接从给定的root路径下获取)
train_data=datasets.CIFAR10(root=root,train=True,transform=transform,download=True)
test_data=datasets.CIFAR10(root=root,train=False,transform=transform,download=True)

print('trainSize: {}'.format(len(train_data)))
print('testSize: {}'.format(len(test_data)))

#下载数据集
root='myDataset/CIFAR10'
#将数据集下载之后保存到root路径下(如果数据集已经存在,则不从网上下载,直接从给定的root路径下获取)
test_data=datasets.CIFAR10(root=root,train=False,download=True)

#显示图片的类别
print('图片包含类别: {}'.format(train_data.classes))

#显示图片
imgOne,target=test_data[0]
imgOne.show()
#查看图片所属类别
print('class: {}'.format(test_data.classes[target]))

 

 

2.重写Dataset类加载数据集

官方文档torch.utls.data.Dataset

 

import os
import pathlib

from PIL import Image
from torch.utils.data import Dataset

class myDataset(Dataset):
    def __init__(self,img_path):
        self.data_dir=pathlib.Path(img_path)
        self.dataset=list(self.data_dir.glob('*/*.jpg'))

    #根据索引index获取数据,index是根据数据集的顺序来获取的
    def __getitem__(self, index):
        img=self.dataset[index]
        imgTo=Image.open(img)
        return imgTo
    
    #获取数据集的大小
    def __len__(self):
        #统计flower_photos文件夹下面所有的图片数据集数量
        self.len=len(list(self.data_dir.glob('*/*jpg')))
        return self.len

if __name__ == '__main__':
    mydataset=myDataset(img_path=r'E:\myDataset\flower_photos')
    print('dataSize: {}'.format(len(mydataset)))
    
    #获取第1张图片数据
    img=mydataset[0]
    print('imgsize: {}'.format(img.size))
    #显示图片
    img.show('img')

 

 

 

3.transforms

打开transforms.py文件可以看到其中包含的对数据的处理方法:(关于这些方法要使用的时候都可以直接查询)

torchvision官网查看功能:

https://pytorch.org/vision/stable/transforms.html 

transforms.compose使用

#compose中包含一个数组,数组中包含的是对图片数据集进行处理的过程
#比如下面,首先对一张图片进行中心的裁剪,其次将PIL数据类型转换为Tensor数据类型,
#最后是将数据转换为浮点类型
transf=transforms.Compose([
    transforms.CenterCrop(10),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float)
])
print('imgSize: {}'.format(img.size))
print(type(transf(img)))


如果读者不使用.compose的话,也可以使用下面的方法一步一步的进行数据转化:

import os
import torch
from PIL import Image
from torchvision import transforms,datasets

img_path="myDataset/flower_photos/daisy/5547758_eea9edfd54_n.jpg"
#读取图片数据
img=Image.open(img_path)
#显示图片大小
print('imgSize: {}'.format(img.size))
#第一步:对图片数据进行中心裁剪
centerCut=transforms.CenterCrop(100)
img_cut=centerCut(img)
img_cut.show('img_cut')
print('imgCutSize: {}'.format(img_cut.size))

#第二步:将图片数据集转换为Tensor
ToTensor=transforms.ToTensor()
img_ToTensor=ToTensor(img_cut)
print(type(img_ToTensor))

#第三步:将Tensor数据转换为浮点类型
FloatData=transforms.ConvertImageDtype(dtype=torch.float)
img_Float=FloatData(img_ToTensor)
print(type(img_Float))

关于上面一些自己比较常用的一些方法 :但是读者应该注意的是,在Compose中使用这些方法时,对数据的处理先后顺序注意,因为有些方法要传入的是Tensor数据类型,所以将数据转换为Tensor类型方法可能得放在其他方法的前面,注意报错的问题所在。

transform=transforms.Compose([
    transforms.Resize(size=[224,224]),
    transforms.CenterCrop(100),
    transforms.ToTensor(),
    #output[channel] = (input[channel] - mean[channel]) / std[channel],由于图片是三通道的,所以平均值和方差都是分别给出三个值
    transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5]),
    #p表示水平翻转的概率
    transforms.RandomHorizontalFlip(p=0.5),
    #垂直翻转
    transforms.RandomVerticalFlip(p=0.5),
    #随机旋转,degrees表示旋转度数,center表示旋转中心坐标,还有其他的参数可以自行选择
    transforms.RandomRotation(degrees=45,center=[50,50],)
])

4.Dataloader对数据进一步处理

官网解释:https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader

以下给出的是一些常见设置参数:

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, num_workers=0)
  • dataset:自己的数据集。
  • batch_size:每一次加载的数据量。
  • shuffle:是否随机打散;当为True时,随机打散,否则默认。
  • num_workers:加载数据所使用的进程数,如果为0,表示默认使用主进程。
import os
import torch
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms,datasets

#数据集的预处理
transform=transforms.Compose([
    transforms.ToTensor()
])

#下载数据集
root='myDataset/CIFAR10'
#将数据集下载之后保存到root路径下(如果数据集已经存在,则不从网上下载,直接从给定的root路径下获取)
train_data=datasets.CIFAR10(root=root,train=True,transform=transform,download=True)
test_data=datasets.CIFAR10(root=root,train=False,transform=transform,download=True)
#获取图片类别
classes=train_data.classes

#加载数据集
train_loader=DataLoader(dataset=train_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)
test_loader=DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)

for data in train_loader:
    imgs,targets=data
    print('imgs: {}'.format(imgs.shape))
    print('target: {}'.format(targets))
    #打印前四张打包的图片类别
    for stop,i in enumerate(targets):
        print('target[{}]---->{}'.format(i,classes[i]))

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用pytorch加载数据集和对数据集进行处理 的相关文章

  • 什么是微服务——微服务架构体系介绍

    Why Microservices 回答这个问题前 xff0c 我们先看下之前大行其道的单体架构 Monolithic Architecture xff0c 对于非专业人士来讲 xff0c 所谓的单体架构 xff0c 其就像一个超大容器 x
  • 微服务架构特征

    一个典型的微服务架构 xff08 MSA xff09 通常包含以下组件 xff1a 客户端 xff1a 微服务架构着眼于识别各种不同的类型的设备 xff0c 以及在此设备上进行的各种管理操作 xff1a 搜索 构建 配置等等身份标识提供者
  • 微服务架构系列——API服务网关

    本章我们简单介绍微服务架构下的API服务网关 xff0c 本章我们将讨论以下话题 xff1a 什么是API服务网关为什么需要API服务网关API服务网关的工作机制 处理横切关注点 当我们在开发设计大型软件应用时 xff0c 我们一般都会采用
  • Java之keytool命令学习

    Java Keytool is a key and certificate management utility It allows users to manage their own public private key pairs an
  • HashMap 与 HashTable的区别

    HashMap 实现了Map接口非线程同步 xff0c 非线程安全不允许重复键键和值均允许为null HashMap lt Interger String gt employeeHashmap 61 new HashMap lt Integ
  • 如何避免敏捷失败?

    很多人都听说敏捷 xff0c 有些人知道敏捷是什么 xff0c 有些人也尝试过敏捷 xff0c 本章中将列举出一些常见的错误敏捷实践 xff0c 如果想要避免敏捷失败 xff0c 建议还是要对照下你所在的敏捷团队中有没有类似的敏捷实践 xf
  • 一个人有文化,到底有多重要?

    关于什么是文化 xff0c 我最最欣赏的回答 xff0c 是作家梁晓声的四句概括 xff1a 根植于内心的修养 xff0c 无需提醒的自觉 xff0c 以约束为前提的自由 xff0c 为别人着想的善良 01 一位叫做 Judy 的空姐 xf
  • MyBatis动态SQL中Map参数处理

    在MyBatis中 xff0c 如果我们需要传递两个参数 xff0c 有一种方式是通过Map作为传入参数 xff0c 在动态SQL中 xff0c 我们需要对传入的Map参数中的值进行判断 xff0c 然后进行动态SQL的条件拼接处理 假设我
  • MyBatis框架下防止SQL注入

    与传统的ORM框架不同 xff0c MyBatis使用XML描述符将对象映射到SQL语句或者存储过程中 xff0c 这种机制可以让我们更大的灵活度通过SQL来操作数据库对象 xff0c 因此 xff0c 我们必须小心这种便利下SQL注入的可
  • 使用android 视频解码mediaCodec碰到的几个问题

    问题1 mediaCodec dequeueInputBuffer一直返回 1 xff0c APP现象 xff1a 视屏卡屏 原因 xff1a 这是因为inputbuffer的内容有误 xff0c 导致无法解码 可通过设延时时间解决 xff
  • 云计算思维导图

    根据近期的云计算学习心得 xff0c 将云计算部分内容制作成思维导图 xff0c 方便于广大云计算学习者作为辅导讲义 xff01 思维导图内容主要包含 xff1a 1 云计算概述 2 云体系结构 3 网络资源 4 存储资源 5 硬件介绍 6
  • 路由器重温——串行链路链路层协议积累

    对于广域网接口来说 xff0c 主要的不同或者说主要的复杂性在于理解不同接口的物理特性以及链路层协议 xff0c 再上层基本都是 IP 协议 xff0c 基本上都是相同的 WAN口中的serial接口主要使用点对点的链路层协议有 xff0c
  • 路由器重温——PPPoE配置管理-2

    四 配置设备作为PPPoE服务器 路由器的PPPoE服务器功能可以配置在物理以太网接口或 PON 接口上 xff0c 也可配置在由 ADSL 接口生成的虚拟以太网接口上 1 配置虚拟模板接口 虚拟模板接口VT和以太网接口或PON接口绑定后
  • Python入门自学进阶——1--装饰器

    理解装饰器 xff0c 先要理解函数和高阶函数 首先要明白 xff0c 函数名就是一个变量 xff0c 如下图 xff0c 定义一个变量名和定义一个函数 xff0c 函数名与变量名是等价的 既然函数名就是一个变量名 xff0c 那么在定义函
  • Python入门自学进阶-Web框架——21、DjangoAdmin项目应用

    客户关系管理 以admin项目为基础 xff0c 扩展自己的项目 一 创建项目 二 配置数据库 xff0c 使用mysql数据库 xff1a 需要安全mysqlclient模块 xff1a pip install mysqlclient D
  • Python入门自学进阶-Web框架——33、瀑布流布局与组合查询

    一 瀑布流 xff0c 是指页面布局中 xff0c 在显示很多图片时 xff0c 图片及文字大小不相同 xff0c 导致页面排版不美观 如上图 xff0c 右边的布局 xff0c 因为第一行第一张图片过长 xff0c 第二行的第一张被挤到第
  • Python入门自学进阶-Web框架——34、富文本编辑器KindEditor、爬虫初步

    KindEditor 是一个轻量级的富文本编辑器 xff0c 应用于浏览器客户端 一 首先是下载 xff1a http kindeditor net down php xff0c 如下图 下载后是 解压缩后 xff1a 红框选中的都可以删除
  • Python入门自学进阶-Web框架——35、网络爬虫使用

    自动从网上抓取信息 xff0c 就是获取相应的网页 xff0c 对网页内容进行抽取整理 xff0c 获取有用的信息 xff0c 保存下来 要实现网上爬取信息 xff0c 关键是模拟浏览器动作 xff0c 实现自动向网址发送请求 xff0c
  • 6、spring的五种类型通知

    spring共提供了五种类型的通知 xff1a 通知类型接口描述Around 环绕通知org aopalliance intercept MethodInterceptor拦截对目标方法调用Before 前置通知org springfram
  • 路由器接口配置与管理——1

    路由器的接口相对于交换机来说最大的特点就是接口类型和配置更为复杂 xff0c 一般吧路由器上的接口分为三大类 xff1a 一类用于局域网的LAN接口 xff0c 一类用于广域网接入 互联的WAN接口 xff0c 最后一类可以应用于LAN组网

随机推荐