基于PyTorch的深度学习--CNN项目代码准备-数据集处理(Extract、Transform和Load)

2023-11-20

本篇文章是翻译:https://deeplizard.com网站中的关于Pytorch学习的文章,供学习使用
原文地址为:https://deeplizard.com/learn/video/8n-TGaBZnk4

使用Pytorch进行提取(E)、转换(T)和加载(L)

欢迎回到基于PyTorch的神经网络编程系列课程。在这一部分,我们将编写我们的第一个代码。
我们将使用torchvision和PyTorch的计算机视觉包进行一个简单的提取、转换和加载过程的展示。废话不多说,现在开始。

概述

在这个项目中,我们将遵循四个步骤:

  1. 准备数据
  2. 建立模型
  3. 训练模型
  4. 分析模型的结果

ETL过程(提取、转换和加载)

在这篇文章中,我们从准备数据开始,为了准备数据,我们将会遵循ETL过程。

  • 从数据源中提取数据。
  • 将数据转换成理想的格式。
  • 将数据加载到合适的结构(模型)中。
    ETL过程可以被认为是一个分形过程,因为它可以应用在不同的尺度上。这个过程可以应用在小的尺度上,就像一个单一的程序。或者应用在一个大的尺度上,如在一个企业级别的巨大系统的每个小部分。
    如果你想知道更多关于通用数据科学,请查看数据科学文章 ,文章中有详细解释。
    一旦我们完成了数据的ETL过程,我们就可以进行模型的创建和训练了。PyTorch中有一些内置的包和类,可以很方便的进行ETL过程。

PyTorch导包

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

下面的表格将介绍上面的各个包。

介绍
torch 最高级别的PyTorch包和tensor库。
torch.nn 包含模型和建立神经网络扩展类的子包。
torch.optim 包含SGD和Adam等标准优化操作到子包。
torch.nn.functional 包含建立神经网络的典型操作的函数接口。(损失函数和卷积等。)
torchvision 为计算机视觉提供通用数据集通道、模型结构和图像转换的包。
torchvision.transforms 为图像处理提供普通转换的接口。

导入其他的包

下面导入使用Python所需要的标准数据包。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix
#from plotcm import plot_confusion_matrix

import pdb

torch.set_printoptions(linewidth=120)

pdb是一个编译器,import(导入)一个本地文件,我们将会在以后的文章进行混淆矩阵(confusion matrix)的介绍。最后一行设置PyTorch打印语句的打印选项。

现在我们可以进行数据的准备了。

使用PyTorch进行数据的准备

我们准备数据是遵循下面步骤:

  1. 提取:从数据源中获取Fashion-MNIST数据集。
  2. 转换:将我们的数据转换成tensor类型。
  3. 加载:将我们的数据放在一个对象中,便于我们访问。
    PyTorch提供给我们两个类:
介绍
torch.utils.data.Dataset 代表一个数据集的抽象类。
torch.utils.data.DataLoader 包装一个数据集并提供对底层数据的访问。

抽象类是Python中必须执行的方法,所以我们可以通过创建一个子类来创建一个自定义数据集,该子类用来扩展 ‘数据集类(抽象类)’ 的功能。
使用PyTorch创建一个自定义类,我们通过创建子类来扩展数据集类的请求方法。在做这个的时候,我们的子类可以被传入PyTorch DataLoader对象。
我们将会使用fashion-MNIST数据集(内置在torchvison包中),因为数据集内置在数据包中,所以我们不必对我们的项目准备fashion-MNIST数据包。只需要知道Fashion-MNIST内置数据集正在幕后工作就可以了。

所有的数据集类的子类都必须进行重写,__len__提供一个数据集的大小,__getitem__支持从0len(self)的整数索引。

PyTorch 的Torchvision包

torchvision数据包为我们提供获得资源的通道。我们可以获得一下资源:

  • Datasets(像MNISTFashion-MNIST数据包)
  • Models(模型,VGG16
  • Transforms
  • Utils

计算机视觉

所有的这些资源都和计算机视觉相关。
我们已经在之前的文章中学习了Fashion-MNIST数据集的介绍,通过arXiv文献我们可以知道Fashion-MNIST的作者主要的意图是制作一个可以代替MNIST的数据集。
这个意图使得在使用PyTorch在添加Fashion-MNIST数据集时只需要更改URL地址就可以了。
在这个案例PyTorch案例中,这个FashionMNIST数据集扩展了MNIST数据集并重写了urls。
这里是这个类定义的torchvision资源代码:

class FashionMNIST(MNIST):
    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``processed/training.pt``
            and  ``processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    urls = [
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
    ]

现在让我们看看如何利用torchvision

PyTorch 数据集类

为了使用torchvision获取fashionMNIST数据集的实例,我们只需要做如下操作:

train_set = torchvision.datasets.FashionMNIST(
    root='./data'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

以前root的参数是'./data/FashionMNIST',然而,现在因为torchvision的更新而改变了。
以上参数的解释:

参数 描述
root 数据在磁盘上的位置
train 数据集是否是用来进行训练的,True表示数据集用来训练
dowenload 这个数据是否应该被下载
transform 应该在数据集元素上执行的转换的组合。

如果我们想让我们的图像转换成tensors格式,我们可以使用内置函数transforms.ToTensor()。因为数据集被用来进行训练,所以我们将其命名为train_set.
当我们第一次运行这个代码的时候,这个Fashion-MNIST数据集将会被下载。随后在数据下载之前进行系统调用检测。因此,我们不用担心数据集会被下载两次。

DataLoader 类

为我们的训练集创建一个DataLoader,我们只需要做:

train_loader = torch.utils.data.DataLoader(train_set
    ,batch_size=1000
    ,shuffle=True
)

我们只需要将train_set作为一个参数传入。现在我们可以使用这个加载器而不是手动进行我们的测试:

  • batch_size :(在我们的案例中我们使用1000)。
  • shuffle:(参数设置为True)
  • num_workers :(默认值是0,意思是我们将使用主程序)

总结

在ETL的观点中,我们完成了提取和使用torchvision进行转换。当我们创建数据集时:

  1. Extract:从网站中提取未加工的数据。
  2. Transform:将未加工的(原始)图像转换成tensor格式。
  3. Load:这个train_set包装了数据加载器,并给我们提供了通往底层数据的通道。
    现在,我们对PyTorch提供的torchvision有了很好的理解,并且也知道了如何使用数据集和数据加载器进行ETL操作。
    在在下一篇文章中,我们将会看到我们可以进行#¥%¥@%¥等一些操作。期待与你的下一次见面。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

基于PyTorch的深度学习--CNN项目代码准备-数据集处理(Extract、Transform和Load) 的相关文章

  • 我可以使用逻辑索引或索引列表对张量进行切片吗?

    我正在尝试使用列上的逻辑索引对 PyTorch 张量进行切片 我想要与索引向量中的 1 值相对应的列 切片和逻辑索引都是可能的 但是它们可以一起吗 如果是这样 怎么办 我的尝试不断抛出无用的错误 类型错误 使用 ByteTensor 类型的
  • 如何避免 PyTorch 中的“CUDA 内存不足”

    我认为对于 GPU 内存较低的 PyTorch 用户来说 这是一个非常常见的消息 RuntimeError CUDA out of memory Tried to allocate X MiB GPU X X GiB total capac
  • 为什么我在这里遇到被零除的错误?

    所以我正在关注这个文档中的教程 https pytorch org tutorials beginner data loading tutorial html在自定义数据集上 我使用的是 MNIST 数据集 而不是教程中的奇特数据集 这是D
  • 在pytorch张量中过滤数据

    我有一个张量X like 0 1 0 5 1 0 0 1 2 0 我想实现一个名为的函数filter positive 它可以将正数据过滤成新的张量并返回原始张量的索引 例如 new tensor index filter positive
  • 一次热编码期间出现 RunTimeError

    我有一个数据集 其中类值以 1 步从 2 到 2 i e 2 1 0 1 2 其中 9 标识未标记的数据 使用一种热编码 self one hot encode labels 我收到以下错误 RuntimeError index 1 is
  • PyTorch:如何检查训练期间某些权重是否没有改变?

    如何检查 PyTorch 训练期间某些权重是否未更改 据我了解 一种选择可以是在某些时期转储模型权重 并检查它们是否通过迭代权重进行更改 但也许有一些更简单的方法 有两种方法可以解决这个问题 First for name param in
  • 从打包序列中获取每个序列的最后一项

    我试图通过 GRU 放置打包和填充的序列 并检索每个序列最后一项的输出 当然我的意思不是 1项目 但实际上是最后一个 未填充的项目 我们预先知道序列的长度 因此应该很容易为每个序列提取length 1 item 我尝试了以下方法 impor
  • 下载变压器模型以供离线使用

    我有一个训练有素的 Transformer NER 模型 我想在未连接到互联网的机器上使用它 加载此类模型时 当前会将缓存文件下载到 cache 文件夹 要离线加载并运行模型 需要将 cache 文件夹中的文件复制到离线机器上 然而 这些文
  • Pytorch ValueError:优化器得到一个空参数列表

    当尝试创建神经网络并使用 Pytorch 对其进行优化时 我得到了 ValueError 优化器得到一个空参数列表 这是代码 import torch nn as nn import torch nn functional as F fro
  • 如何使用pytorch构建多任务DNN,例如超过100个任务?

    下面是使用 pytorch 为两个回归任务构建 DNN 的示例代码 这forward函数返回两个输出 x1 x2 用于大量回归 分类任务的网络怎么样 例如 100 或 1000 个输出 对所有输出 例如 x1 x2 x100 进行硬编码绝对
  • Pytorch GPU 使用率低

    我正在尝试 pytorch 的例子https pytorch org tutorials beginner blitz cifar10 tutorial html https pytorch org tutorials beginner b
  • PyTorch 中的连接张量

    我有一个张量叫做data形状的 128 4 150 150 其中 128 是批量大小 4 是通道数 最后 2 个维度是高度和宽度 我有另一个张量叫做fake形状的 128 1 150 150 我想放弃最后一个list array从第 2 维
  • 如何在 PyTorch 中对子集使用不同的数据增强

    如何针对不同的情况使用不同的数据增强 转换 Subset在 PyTorch 中吗 例如 train test torch utils data random split dataset 80000 2000 train and test将具
  • 如何计算cifar10数据的平均值和标准差

    Pytorch 使用以下值作为 cifar10 数据的平均值和标准差 变换 Normalize 0 5 0 5 0 5 0 5 0 5 0 5 我需要理解计算背后的概念 因为这些数据是 3 通道图像 我不明白什么是相加的 什么是除什么的等等
  • 将 Pytorch LSTM 的状态参数转换为 Keras LSTM

    我试图将现有的经过训练的 PyTorch 模型移植到 Keras 中 在移植过程中 我陷入了LSTM层 LSTM 网络的 Keras 实现似乎具有三种状态类型的状态矩阵 而 Pytorch 实现则具有四种状态矩阵 例如 对于hidden l
  • 使用 PyTorch 分布式 NCCL 连接失败

    我正在尝试使用 torch distributed 将 PyTorch 张量从一台机器发送到另一台机器 dist init process group 函数正常工作 但是 dist broadcast 函数中出现连接失败 这是我在节点 0
  • Pytorch RuntimeError:“host_softmax”未针对“torch.cuda.LongTensor”实现

    我正在使用 pytorch 来训练模型 但是在计算交叉熵损失时我遇到了运行时错误 Traceback most recent call last File deparser py line 402 in
  • 在requirements.txt中包含.whl安装

    如何将其包含在requirements txt 文件中 对于Linux pip install http download pytorch org whl cu75 torch 0 1 12 post2 cp27 none linux x8
  • PyTorch DataLoader 对并行运行的批次使用相同的随机种子

    有一个bug https tanelp github io posts a bug that plagues thousands of open source ml projects 在 PyTorch Numpy 中 当并行加载批次时Da
  • 尝试将 cuda 与 pytorch 一起使用时出现运行时错误 999

    我为我的 Geforce 2080 ti 安装了 Cuda 10 1 和最新的 Nvidia 驱动程序 我尝试运行一个基本脚本来测试 pytorch 是否正常工作 但出现以下错误 RuntimeError cuda runtime erro

随机推荐

  • Oracle 实现select(查询)的结果集随机顺序展示

    在一些需求中会要求打乱结果集顺序随机展示 Oracle的实现方式如下 select from table order by dbms random value 这种用法没有参数 会返回一个具有38位精度的数值 范围从0 0到1 0 但不包括
  • stm32实用篇5:HAL库 DHT11 驱动

    DHT11是很常用的温湿度传感器 时序也比较简单 如下所示 直接给出HAL库的驱动 1 微秒级延时函数 HAL库并没有直接的微秒级延时函数 下面是自己实现的微秒堵塞延时函数 使用定时器TIM3 brief 微秒级延时 void bsp de
  • 格式化时间用了YYYY-MM-dd,元旦当天老板喊我回去改Bug!

    视频福利推荐 2T免费学习视频 内含SSM Spring全家桶 微服务 MySQL MyCat 集群 分布式 中间件 Linux 网络 多线程 Jenkins Nexus Docker ELK等等免费学习视频 持续更新 往期热门文章 1 往
  • Linux面试题1

    一 取出 etc passwd文件中shell出现的次数 问题 下面是一个 etc passwd文件的部分内容 题目要求取出shell并统计次数 shell是指后面的 bin bash sbin nologin等 如下面 bin bash出
  • 主线剧情0.0-Linux学习资源大综合

    Linux 学习资源大综合 对收集到的比较丰富的 Linux 学习相关的资料进行整理 注 如果链接挂了请告诉我 如果链接里的内容被删了那么直接搜文章名字试试也许会搜出来很多转载的 备份 注 在 Github 上的原版文章日后可能会更新 在其
  • 37: 合并区间

    题目 以数组 intervals 表示若干个区间的集合 其中单个区间为 intervals i starti endi 请你合并所有重叠的区间 并返回 一个不重叠的区间数组 该数组需恰好覆盖输入中的所有区间 思路 这道题我的思路完全正确 先
  • 杰理之提示音的使用【篇】

    打开 SDK 对应的 cpu brxx tools ACxxxN config tool 进入配置工具入口 gt 选择编译前配置 工具 gt 提示音配置
  • 经典面试题之new和malloc的区别

    new和malloc的区别是C C 一道经典的面试题 我也遇到过几次 回答的都不是很好 今天特意整理了一下 0 属性 new delete是C 关键字 需要编译器支持 malloc free是库函数 需要头文件支持 1 参数 使用new操作
  • python3 ACM 输入输出

    Python的输入是字符串 所有需要自己转化为对应的类型 strip去掉左右两端的空白符 返回str slipt把字符串按空白符拆开 返回 str map把list里面的值映射到指定类型 返回 type 有多组输入数据 但没有具体的告诉你有
  • 如何在JavaScript中实现链式调用(chaining)?

    聚沙成塔 每天进步一点点 专栏简介 JavaScript中的链式调用 示例 写在最后 专栏简介 前端入门之旅 探索Web开发的奇妙世界 记得点击上方或者右侧链接订阅本专栏哦 几何带你启航前端之旅 欢迎来到前端入门之旅 这个专栏是为那些对We
  • 程序员思维模式 - 主调试循环

    文章目录 主调试循环 验证在图层中进行 优化循环时间 为什么快速循环更好 短循环时间是通用的吗 一些综合测试是必要的 复杂性是否会导致测试验证循环 救援的暂存环境 结论 仅通过测试进行验证基本上是在仪器上驾驶飞机 而不是能够向外看挡风玻璃
  • 文心千帆为你而来

    1 前言 3月16号百度率先发布了国内第一个人工智能大语言模型 文心一言 文心一言的发布在业界引起了不小的震动 而文心一言的企业服务则由文心千帆大模型平台提供 文心千帆大模型平台是百度智能云打造出来的一站式大模型开发与应用平台 提供包括文心
  • 工业表面缺陷检测数据集汇总

    1 数据集名称 NEU CLS 应用场景 钢材表面 链接 http faculty neu edu cn songkechen zh CN zhym 263269 list 2 数据集名称 elpv dataset 应用场景 太阳能板 链接
  • 什么是Servlet容器?

    在本文中 我写了一些关于Web服务器 Servlet容器以及它与JVM的关系的基本概念 我想表达的是 Servlet容器也仅仅不过是一个Java程序 1 什么是Web服务器 想要知道什么是Servlet容器 我们首先要知道什么是Web服务器
  • 整理了27个Python人工智能库,建议收藏!

    来源丨网络 大家好 我是阳哥 为了大家能够对人工智能常用的 Python 库有一个初步的了解 以选择能够满足自己需求的库进行学习 对目前较为常见的人工智能库进行简要全面的介绍 1 Numpy NumPy Numerical Python 是
  • uni-app打包ios应用后,屏幕无法占满,上下出现黑框

    软件打包后是成功的 功能也都正常 就是打开软件后上下都出现了黑框 整个软件变小了 5s的屏结果运行的是4s的效果 就像ipad运行了iphone软件一样的那种感觉 这是由于ios缺少启动图引起的 勾选通用启动界面即可 在manifest j
  • java工作记录问题总结

    1 注解等同于 controller Component 2 定时器立即执行一次 每小时执行一次 Async 异步 Scheduled fixedRate 1000 60 60 3 double数据量过大时使用 BigDecimal dou
  • windows下安装python的pip指令

    windows下安装python的pip指令 安装pip前 确定Windows下有python和easy install安装包 确定windows系统中有python环境 并将python解释器配置到系统的环境变量中 1 环境变量中添加py
  • 按照公式,将经纬度转为椭球

    目前墨卡托投影的纹理坐标已经绑定 现在转为椭球体 将经纬度中按照公式直接转为椭球上的xyz 即可 也可以参照三维引擎设计创建椭球
  • 基于PyTorch的深度学习--CNN项目代码准备-数据集处理(Extract、Transform和Load)

    本篇文章是翻译 https deeplizard com网站中的关于Pytorch学习的文章 供学习使用 原文地址为 https deeplizard com learn video 8n TGaBZnk4 使用Pytorch进行提取 E