Pytorch学习(二)使用 torchvision

2023-11-18

训练图像分类器

官方教程
我们将按顺序执行以下步骤:
1.使用使用 torchvision
2.定义卷积神经网络
3.定义损耗函数
4.根据培训数据对网络进行训练
5.在测试数据上测试网络
这篇博文为第一步

准备数据集

在训练神经网络前,必须有数据。可以使用以下几个数据提供源。
准备图片数据集
一、CIFAR-10CIFAR-10
二、ImageNetImageNet
三、ImageFolderImageFolder
四、LSUN ClassificationLSUN Classification
五、COCO (Captioning and Detection)COCO

torchvision

为了方便加载以上五种数据库的数据,torchvision 是 PyTorch 中专门用来处理图像的库。使用torchvision就可以轻松实现数据的加载和预处理。
我们以使用CIFAR10为例:
在这里插入图片描述

一. 导入torchvision的库

import torchvision
import torchvision.transforms as transforms   # transforms用于数据预处理

二. 使用datasets.CIFAR10()函数加载数据库

CIFAR10有60000张图片,其中50000张是训练集,10000张是测试集。

#训练集,将相对目录./data下的cifar-10-batches-py文件夹中的全部数据
#(50000张图片作为训练数据)加载到内存中,若download为True时,会自动从网上下载数据并解压
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 
download=False, transform=None)

下面简单讲解root、train、download、transform这四个参数
1.root,表示cifar10数据的加载的相对目录
2.train,表示是否加载数据库的训练集,false的时候加载测试集
3.download,表示是否自动下载cifar数据集
4.transform,表示是否需要对数据进行预处理,none为不进行预处理
由于美帝路途遥远,靠命令台进程下载100多M的数据速度很慢,所以我们可以自己去到cifar10的官网上把CIFAR-10 python version下载下来,然后解压为cifar-10-batches-py文件夹,并复制到相对目录./data下。(若设置download=True,则程序会自动从网上下载cifar10数据到相对目录./data下,但这样小伙伴们可能要等一个世纪了),并对训练集进行加载(train=True)
在脚本文件下建一个data文件夹,然后把数据集文件夹丢到里面去就好了,注意cifar-10-batches-py文件夹名字不能自己任意改。
接下来看一下trainset的大小

print len(trainset)
#结果:50000

在这里插入图片描述

三. DataLoader用多进程加速batch data的处理

我们在训练神经网络时,使用的是mini-batch(一次输入多张图片),所以我们在使用一个叫DataLoader的工具为我们将50000张图分成每四张图一分,一共12500份的数据包。

#将训练集的50000张图片划分成12500份,每份4张图,用于mini-batch输入。
#shffule=True在表示不同批次的数据遍历时,打乱顺序(这个需要在训练神经网络时再来讲)。
#num_workers=2表示使用两个子进程来加载数据

import torch
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, 
shuffle=False, num_workers=2)

四. 对数据预处理

接下来需要对数据进行预处理,预处理会帮助我们加快神经网络的训练。
预处理用到了transforms函数:

transform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])

compose函数会将多个transforms包在一起。
1.ToTensor是指把PIL.Image(RGB) 或者numpy.ndarray(H x W x C) 从0到255的值映射到0到1的范围内,并转化成Tensor格式。
2.Normalize(mean,std)是通过下面公式实现数据归一化
transforms.Normalize( mean = (0.5,0.5,0.5), std = (0.5,0.5,0.5) )并不是指将张量的均值和标准差设为0.5,而是做这么一个运算:输入的每个channel做 ( [0, 1] - mean(0.5) )/ std(0.5)= [-1, 1] 的运算,所以这一句的实际结果是将[0,1]的张量归一化到[-1, 1]上。

channel=(channel-mean)/std 

经过上面两个操作,我们的数据中的每个值就变成了[-1,1]的数了。

完整代码

1.先引入库

import torch
import torchvision
import torchvision.transforms as transforms

2.数据预处理,归化为tensor数据,并归一化为[-1, 1]

# torchvision输出的是PILImage,值的范围是[0, 1].
# 我们将其转化为tensor数据,并归一化为[-1, 1]。
transform=transforms.Compose([transforms.ToTensor(),
                             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                             ])

3.加载训练集数据

#训练集,将相对目录./data下的cifar-10-batches-py文件夹中的全部数据
#(50000张图片作为训练数据)加载到内存中,若download为True时,会自动从网上下载数据并解压
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 
download=False, transform=transform)

4.划分数据,加速处理

#将训练集的50000张图片划分成12500份,每份4张图,用于mini-batch输入。
#shffule=True在表示不同批次的数据遍历时,打乱顺序。num_workers=2表示使用两个子进程来加载数据
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, 
shuffle=False, num_workers=2)

5.展示训练图像,直观感受
对代码中一部分的理解,参考plt.imshow(np.transpose(npimg, (1, 2, 0)))在pytorch中,读入图片并进行显示的方式有两种,见博文。

#添加CIFAR10的标签
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
import matplotlib.pyplot as plt
import numpy as np

# functions to show an image
def imshow(img):
    img = img / 2 + 0.5     # unnormalize 非标准的反归一化
    #因为归一话的时候是先减去平均值0.5 ,然后再除以标准偏差0.5那么反归一化就是先乘以0.5,再加0.5。
    npimg = img.numpy()#Image互转化为numpy,将torch.FloatTensor 转换为numpy
    #因为在plt.imshow现在实际输入的是(imagesize,imagesize,channels),
    #而定义中参数img的格式为(channels,imagesize,imagesize),这两者的格式不一致
    #我们需要调用np.transpose函数,即np.transpose(npimg,(1,2,0))
    #将npimg的数据格式由(channels,imagesize,imagesize)转化为(imagesize,imagesize,channels),进行格式的转换后方可进行显示。
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()
# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s'%classes[labels[j]] for j in range(4)))

输出图片:
在这里插入图片描述
输出标签:
在这里插入图片描述

6.运行时出现的问题
在这里插入图片描述
原因:多进程需要在main函数中运行
解决办法:参考博客
1.加main函数,在main中调用
2.num_workers改为0,单进程加载

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch学习(二)使用 torchvision 的相关文章

随机推荐

  • 4G路由器设置

    总共分四步如下图所示 1 用网线连接电脑 2 给路由器上电 3 设置电脑网络 如图打开电脑网络和共享数据中心选中本地连接双击 弹出如下图所示弹框选择图中ipv4双击 根据下面图片配置ip地址 4 浏览器上输入地址访问路由器进行配置 第一步
  • CV 经典主干网络 (Backbone) 系列: CSP-Darknet53

    CSP Darknet53 0 引言 1 网络结构图 1 1 输入部分 1 2 CSP部分结构 1 3 输出部分 2 代码实现 2 1 代码整体实现 2 2 代码各个阶段实现 3 代码测试 4 结论 0 引言 CSP Darknet53无论
  • Halcon实战记录之二《判断两个直线或者矩形是否相交》

    项目中使用到需要判断两个矩形是否相交 由于我使用Halcon不久 对其算子还不熟悉 不知道是否有现成的算子可以直接实现 如果有 还请各位朋友给留言指出 先谢谢了 我这里用了如下的方法 1 如果两个矩形相交 那么它们中的线段一定会有相交的 我
  • LeetCode 687. 最长同值路径

    题目链接 https leetcode cn problems longest univalue path C 代码如下 Definition for a binary tree node struct TreeNode int val T
  • 优惠券的设计分享

    优惠券是一种常见的促销手段 在形式上给予消费者心理一定的折扣 然后促成订单 本文主要分享关于优惠券的设计 一 引子 促销活动的目的按对象可分为对用户 对产品 对公司 其中对用户的促销目的又可分为三种 拉新 促活 留存 优惠券作为一种常见的促
  • 前端基础知识与常见面试题(九)

    描述 现有n种砝码 重量互不相等 分别为 m1 m2 m3 mn 每种砝码对应的数量为 x1 x2 x3 xn 现在要用这些砝码去称物体的重量 放在同一侧 问能称出多少种不同的重量 注 称重重量包括 0 输入描述 对于每组测试数据 第一行
  • 逆向某联盟RSA登录

    目录 1 抓包分析 2 逆向 1 抓包分析 经典抓包套路 发现载荷password的参数进行了加密 还是如此之长 那就可以猜测是RSA加密了 点击启动器 找到login位置 然后搜索password 发现果然是RSA加密 人家还贴切的给了注
  • 零基础入门STM32编程——GPIO(五)

    系列教程链接 HAL库编程点灯篇https blog csdn net oHaoEr article details 122999523 一 GPIO简介 1 1 概述 GPIO 通用输入输出端口 即芯片的IO管脚 STM32F103系列中
  • 深度学习训练之optimizer优化器(BGD、SGD、MBGD、SGDM、NAG、AdaGrad、AdaDelta、Adam)的最全系统详解

    文章目录 1 BGD 批量梯度下降 2 SGD 随机梯度下降 2 1 SGD导致的Zigzag现象 3 MBGD 小批量梯度下降 3 1 BGD SGD MBGD的比较 4 SGDM 5 NAG 6 AdaGrad Adaptive Gra
  • EndNote在word中进行文献引用的插入时,没有出现编号问题

    转载链接 https blog csdn net qq 32120957 article details 83547621 EndNote 是一个著名的参考文献管理软件 用来创建个人参考文献库 并且可以加入文本 图像 表格和方程式等内容及链
  • 网络编程---TCP/UDP套接字编程原理

    本篇介绍的是Linux下的网络编程 故有些接口是不适用于Windows的 但是具体概念和实现方法是大体一致的 本篇重在讲解原理 具体实现请戳这里 gt UDP套接字编程实现 介绍 网络编程套接字 socket 也是进程间通信的一种方式 但是
  • 浅谈Canvas和SVG的区别

    各位都知道canvas是html5提供的新元素 而svg存在的历史要比canvas久远 已经有十几年了 svg并不是html5专有的标签 Canvas和SVG的区别在哪呢 那我们就看看它们的特点 1 SVG SVG可缩放矢量图形 Scala
  • 基于卷积神经网络的人脸表情识别综述

    基于卷积神经网络的人脸表情识别 摘要 在日常的沟通与交流过程中 运用面部表情可以促使沟通交流变得更加顺畅 因此对于人类而言 进行面部表情的解读也是进行相关沟通交流内容获取的重要程序 随着科学技术的不断发展 人工智能在日常人类交流沟通中 运用
  • Jenkins+Python完整版

    一 简介 一般网站部署的流程 这边是完整流程而不是简化的流程 需求分析 原型设计 开发代码 内网部署 提交测试 确认上线 备份数据 外网更新 最终测试 如果发现外网部署的代码有异常 需要及时回滚 一般是运维来做 功能测试 上线的时间 jen
  • 北京的IT崩盘了么?

    相信今年的互联网行情 大家都有目共睹 身边被各种裁员 劝退的朋友比往年要多了很多 而如今想要找一份还不错的工作 难度也是直线上升 我个人的感受是 这行情就像股价 有起有落 目前处于衰退期 崩盘倒是不至于 网上对此也有很多看法 今天分享一些
  • Qt保存Excel格式数据

    目录 前言 1 下载源码 2 编译源码 3 写Excel数据示例 前言 本文以一个示例介绍了如何使用 libxlsxwriter 开源库保存QTableWidget表格中的数据到Excel文件 libxlsxwriter 是一个C语言库 可
  • nmap操作系统检测_Nmap操作系统检测

    nmap操作系统检测 rps include post 6632 rps include post 6632 One of the most popular feature of nmap is its Operating System d
  • 关于redis 5.0 新数据类型 Stream

    redis 5 0 新特性见 https www oschina net news 100931 redis 5 0 released p 2 对于stream 详细使用和解释见https www zhihu com question 27
  • 宋浩线性代数笔记(五)矩阵的对角化

    本章的知识点难度和重要程度都是线代中当之无愧的T0级 对于各种杂碎的知识点 多做题 复盘才能良好的掌握 良好掌握的关键点在于 所谓的性质A与性质B 是谁推导得谁
  • Pytorch学习(二)使用 torchvision

    Pytorch学习 二 使用 torchvision 训练图像分类器 准备数据集 torchvision 一 导入torchvision的库 二 使用datasets CIFAR10 函数加载数据库 三 DataLoader用多进程加速ba