Pytorch中常见transform的使用

2023-10-27

本次实验练习了pytorch中数据的读取,Dataset类的使用,以及transform模块的使用。

一、Pytorch简介

PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。

2017年1月,由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch。它是一个基于Python的可续计算包,提供两个高级功能:1、具有强大的GPU加速的张量计算(如NumPy)。2、包含自动求导系统的深度神经网络。

二、Pytorch的环境配置

关于Pytorch的环境配置网上有好多教学,这里不做赘述。

三、Dataset类的基本使用

Dataset类:处理数据,提供一种方式挑选数据及其对应的label。

Dataloader类:对Dataset挑选后的数据进行打包,为后面的网络提供不同的数据形式。

1、首先导入Dataset类

from torch.utils.data import Dataset

2、创建一个类,继承Dataset类

class MyData(Dataset):



    def __init__(self, root_dir, label_dir):

        self.root_dir = root_dir

        self.label_dir = label_dir

        # os.path.join的意思是把这两个路径拼接

        # 如root路径是dataset\train,label路径是ants,拼接后的结果是dataset\train\\ants

        self.path = os.path.join(self.root_dir, self.label_dir)

        # os.listdir(path)

        # 作用:传入任意一个path路径,返回的是该路径下所有文件和目录组成的列表;

        self.img_path = os.listdir(self.path)



    # 这个函数作用是获取其中的每一个图片

    def __getitem__(self, idx):

        # idx是图片的索引,img_name是获取图片

        img_name = self.img_path[idx]

        # 把图片的路径也拼接上

        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)

        # 打开图片

        img = Image.open(img_item_path)

        # 需要用到标签

        label = self.label_dir

        # 返回标签和读取的图片

        return img, label



    def __len__(self):

        # 返回有多少张图

        return len(self.img_path)

四、常见transform的使用

首先导入SummaryWriter函数,此函数的作用是将图片在浏览器中显示。

writer = SummaryWriter('logs')
img = Image.open(
'images/220927.png').convert('RGB')

1、ToTensor方法:


这个类可以接受的图像类型为Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

# ToTensor的使用
trans_totensor = transforms.ToTensor()
img_tensor = trans_totensor(img)
writer.add_image(
"TotTensor", img_tensor)

在终端输入

tensorboard --logdir="logs" --port=6007

点击链接进入浏览器输出图像如下

 

这个方法的作用是将图片转换为tensor类型。

2、Normalize方法

归一化类,需要传入均值和标准差。

tran_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
# 传入需要归一化的图片
img_norm = tran_norm(img_tensor)
print(img_norm[0][0][0])
writer.add_image(
'Normalize', img_norm, 2)

输出结果如下

 

3、Resize方法

即改变图片尺寸

trans_resize = transforms.Resize((3, 3))
img_resize = trans_resize(img_tensor)
print(img_resize)
writer.add_image(
'Resize', img_resize, 0)

输出结果如下

 

4、Compose方法

compose()用法:其中的参数需要的是一个列表,列表中的数据类型是transforms,意义是把两个类的方法合并。

trans_resize_2 = transforms.Resize(512)
trans_compose = transforms.Compose([trans_resize
, trans_resize_2])
img_resize_2 = trans_compose(img_tensor)

print(img_resize_2)
writer.add_image(
'resize', img_resize_2, 1)

输出结果如下

 

五、dataset类与transform的结合使用

首先下载数据集,因为仅作练习使用,所以下载较小的CIFAR10数据集。

root是保存的目录,train=True时下载的时训练集,反之下载数据集,将下载的数据集转换为tensor类型

train_set = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=dataset_transformdownload=True)
test_set = torchvision.datasets.CIFAR10(
root='./dataset', train=False, transform=dataset_transform, download=True)

取测试集的前十张图片传入浏览器

writer = SummaryWriter('logs')
for i in range(10):
    img
, target = test_set[i]
    writer.add_image(
'test_set', img, i)
writer.close()

输出结果如下

 

2、dataloader的使用

Dataloader类:对Dataset挑选后的数据进行打包,为后面的网络提供不同的数据形式。

准备的测试数据集

test_data = torchvision.datasets.CIFAR10('./dataset', train=False, transform=torchvision.transforms.ToTensor())

batch_size=4的意思是每次从数据集中取出4个数据进行打包

test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)

将打包好的图片在浏览器中显示

step = 0
writer = SummaryWriter('dataloader')
for data in test_loader:
    imgs
, targets = data
    writer.add_images(
'test_data', imgs, step)
    step = step+
1

writer.close()

输出结果如下

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch中常见transform的使用 的相关文章

随机推荐

  • Linux命令:pidof

    pidof命令 查询某个指定服务进程的PID值 每个进程的进程号码 PID 是唯一的 因此可以通过PID来区分不同的进程 执行以下命令查询sudo服务的PID root LAPTOP HJMUH10E home simon pidof su
  • JVM-从熟悉到精通

    JVM 机器语言 一个指令由操作码和操作数组成 方法调用等于一个压栈的过程 栈有 BP寄存器 和 SP寄存器来占用空间 BP gt Base Point 栈基址 栈底 SP gt Stack Point 栈顶 字节序用于规定数据在内存单元如
  • CUDA并行算法系列之FFT快速卷积

    CUDA并行算法系列之FFT快速卷积 卷积定义 在维基百科上 卷积定义为 离散卷积定义为 0 1 2 3 和 0 1 2 的卷积例子如下图所示 Python实现 直接卷积 根据离散卷积的定义 用Python实现 def conv a b N
  • RNN, LSTM, GRU模型结构详解(看这一篇就够了)

    RNN和LSTM讲解超详细的文章 https zhuanlan zhihu com p 32085405 GRU超详解文章 https zhuanlan zhihu com p 32481747
  • jupyter notebook 导出 markdown文件格式

    jupyter notebook 导出 markdown文件格式 原本jupyter notebook 里面自带的可以选择导出为markdown格式 但是下载之后文件总是打不开 只能另寻他法 方法 第一步 安装nbconvert pip i
  • C++类和对象的基本概念

    目录 1 c和c 中struct的区别 2 类的封装 3 类的访问权限 1 c和c 中struct的区别 c语言中结构体中不能存放函数 也就是数据 属性 和行为 方 法 是分离的 c 中结构体中是可以存放函数的 也就是数据 属性 和行为 方
  • Linux文件编程常用函数详解——fcntl()函数

    fcntl 函数 include
  • 智能指针(二):shared_ptr实现原理

    前面讲到auto ptr有个很大的缺陷就是所有权的转移 就是一个对象的内存块只能被一个智能指针对象所拥有 但我们有些时候希望共用那个内存块 于是C 11标准中有了shared ptr这样的智能指针 顾名思义 有个shared表明共享嘛 所以
  • windows升级node版本

    当本地的node版本过低的时候 这就需要升级更高版本来满足开发需求 本文详细教大家如何升级自己需要的node版本 1 官网 下载 Node js 中文网 下载找到需要升级的node版本 下载也默认只有长期支持版本和最新版本 如果满足需求 直
  • 2020 MCM Weekend 2 Problem C,2020美赛C题——完整版题目

    文章目录 Problem C A Wealth of Data Problem Requirements Glossary Data Set Definitions Problem C A Wealth of Data Problem In
  • 测试开发岗需要学习什么样的技能才能满足需求?也许通过阅读各个互联网大厂的JD你会更加清楚

    目录 前言 各大互联网厂关于测试开发的要求 实习 测试开发实习生 测试中心 B站 测试开发实习生 商业技术部 B站 测试开发实习生 直播 B站 测试开发工程师 实习 阿里 游戏测试开发工程师 实习 阿里 测试开发工程师 教育业务 实习 字节
  • 时间序列之指数平滑法(Exponential Smoothing)

    统计中 预测方法除了利用多个影响因素建立回归模型来做预测外 在影响因素复杂 或者是没办法得到相关影响因素的数据信息时 回归模型就无能为力了 如果数据是时间序列上的值 在时间上可能呈现一定的稳态或者规律 利用过去时间区间的值来预测未来值 指数
  • 关于Win2008系统DNS服务器安装配置操作教程

    DNS是因特网的一项核心服务 它作为可以将域名和IP地址相互映射的一个分布式数据库 能够使人更方便的访问互联网 而不用去记住能够被机器直接读取的IP 中文全称 网络协议 地址数串 在win2008系统中要成功安装DNS服务器才能够正常的连接
  • Python工程师的发展前景如何?薪资高吗?5点给你分析齐全

    根据网上的人爆料 2020 互联网大厂校招硕士生的薪资情况 和美团今年的校招信息发布 也是引起一波热潮 许多人看到这些薪资都会感叹一声 那真正处于技术岗位的人员又是另一种看法 同时也激起了许多人想学编程的想法 而目前较为火热的Python也
  • 可视化翻转教学python

    目录 第1关 绘制折线图 第2关 绘制正弦曲线 第3关 绘制指定线型 颜色和标记的正弦曲线 第4关 定义绘制正余弦函数曲线的函数 第5关 绘制坐标轴并设置范围 第1关 绘制折线图 显示绘制结果 plt show 用于显示绘制的结果 无参数
  • 华为OD机试 - 报数问题(Java)

    题目描述 有n个人围成一圈 顺序排号为1 n 从第一个人开始报数 从1到3报数 凡报到3的人退出圈子 问最后留下的是原来第几号的那位 输入描述 输入人数n n lt 1000 输出描述 输出最后留下的是原来第几号 用例 输入 2 输出 2
  • PHP 密码长度至少为8,且必须包含大小写字母/数字/符号任意三者组合

    密码长度至少为8 且必须包含大小写字母 数字 符号任意三者组合 public function rexCheckPassword pwd 12345678aaA 8 20 位 字母 数字 字符 密码必须包含大小写字母 数字 符号任意两者组合
  • 程序员必知的设计模式七大原则

    文章目录 设计模式的目的 1 单一职责原则 1 1 单一职责原则注意事项和细节 2 接口隔离原则 2 1 接口隔离原则例子 3 依赖倒转原则 3 1 什么是依赖 3 2 依赖关系传递的三种方式 1 接口传递 依赖 2 构造方法传递 组合 3
  • 用U深度启动U盘清除系统登入密码

    先添加一块硬盘 修改启动顺序 选择windows密码破解工具 选择选项1 出现了许多硬盘 一个一个去试SAM在那个硬盘 最后发现在硬盘2 出现以下界面 选择第一个用户 按y键保存并退出 在按esc键一直退到以下界面 输入r退出关闭计算机 把
  • Pytorch中常见transform的使用

    本次实验练习了pytorch中数据的读取 Dataset类的使用 以及transform模块的使用 一 Pytorch简介 PyTorch是一个开源的Python机器学习库 基于Torch 用于自然语言处理等应用程序 2017年1月 由Fa