pytorch实战-图像分类(一)(数据预处理)

2023-11-18

目录

1.导入各种库

2.数据预处理

2.1数据读取

2.2图像增强

3.构建数据网络 

3.1网络构建

3.2读取标签对应的名字

4.展示数据

4.1数据转换

4.2画图

5.模型训练


1.导入各种库

上代码:

import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
#pip install torchvision
from torchvision import transforms, models, datasets
#https://pytorch.org/docs/stable/torchvision/index.html
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

2.数据预处理

2.1数据读取

先看以下训练集和验证集存放的位置

 上代码

data_dir = './flower_data/'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'

2.2图像增强

目的:我们所收集准备训练的数据都是很可贵的,数据越多成本也就越高,所以希望将有限的数据集最大化利用,这就时图像增强的目的。

定义:如下图小灰猫,进行翻转操作,小黄猫,进行不同角度的旋转操作,这样实现了一图多用的效果,在原数据的基础上,将数据集翻了几倍。比方说你现在有一个1w的数据集,经过数据增强,可以完成10w的数据集。

 上代码

data_transforms = {
    'train': transforms.Compose([transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选
        transforms.CenterCrop(224),#从中心开始裁剪(224×224),因为训练集收集的图大小可能不同,但神经网络需要同样大小的输入.
        transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率,p=0.5就是说,有50%概率执行该操作。
        transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
        transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
        transforms.ToTensor(), #将数据转化成tensor格式输入
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#因为本例是要用别人的模型训练,所以要参考别人例子中提供的均值,标准差,对自己的的训练集进行标准化操作。
    ]),
    'valid': transforms.Compose([transforms.Resize(256), #验证集不需要做数据增强,其他处理方法和train一样。
        transforms.CenterCrop(224), #验证集数据裁剪成和训练集一样,才能对比
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

3.构建数据网络 

3.1网络构建

batch_size = 8

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']} # 构建分类任务数据集,注意不同任务数据集构建方式不同。
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']} # 按照batch_size = 8大小加载数据。
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']} # 看一下数据的数量,该例'train': 6552, 'valid': 818
class_names = image_datasets['train'].classes

3.2读取标签对应的名字

网络最后的输出是一个代表类别的数值,比方说1,2,3,但我们希望看到这个数值对应的类别,所以json存这些信息,比方说{'1': 'pink primrose'}。

with open('cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f) 

4.展示数据

4.1数据转换

注意:进行训练时需要tensor格式的数据,所以展示的时候tensor的数据需要转换成numpy的格式,而且还需要还原回标准化的结果。

def im_convert(tensor): #im_convert转化函数
    """ 展示数据"""
    
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1,2,0)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1)

    return image

4.2画图

fig=plt.figure(figsize=(20, 12))
columns = 4
rows = 2

dataiter = iter(dataloaders['valid'])
inputs, classes = dataiter.next()

for idx in range (columns*rows):
    ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()

5.模型训练

下接该文:pytorch实战-图像分类(二)(模型训练及验证)(基于迁移学习(理解+代码))

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch实战-图像分类(一)(数据预处理) 的相关文章

随机推荐

  • 动漫数据推荐系统

    Simple TfidfVectorizer and CountVectorizer recommendation system for beginner 简单的TfidfVectorizer和CountVectorizer推荐系统 适用于
  • stm32单片机之外部脉冲捕获例程

    stm32单片机之外部脉冲捕获例程 定时器通道1来捕获外部脉冲 并且当脉冲到来时 通过HAL库的回调函数来处理这个事件 include stm32f4xx hal h 定义一个TIM HandleTypeDef结构体 TIM HandleT
  • 字符串长度检查

  • docker安装mysql 及 ls: cannot access ‘/docker-entrypoint-initdb.d/‘: Operation not permitted问题解决

    目录 查看本地镜像 搜索可用mysql 拉取最新版本 运行镜像 查看进程是否正常 问题解决 查看本地镜像 查看本地是否已经有mysql镜像了 docker images grep mysql 正常此步骤不会有返回结果 搜索可用mysql d
  • kaggle数据挖掘竞赛初步--Titanic<随机森林&特征重要性>

    完整代码 https github com cindycindyhi kaggle Titanic 特征工程系列 Titanic系列之原始数据分析和数据处理 Titanic系列之数据变换 Titanic系列之派生属性 维归约 之前的三篇博文
  • 模式识别学习笔记之一:模式识别的步骤及相关概念

    1 信息获取 2 预处理 对获取信号进行规范化等各种处理 3 特征提取与选择 将识别样本构造成便于比较 分析的描述量即特征向量 4 分类器设计 由训练过程将训练样本提供的信息变为判别事物的判别函数 5 分类决策 对样本特征分量按判别函数的计
  • 学习二叉树必须要了解的各种遍历方式及节点统计

    哈喽 大家好 我是小林 今天给大家分享一下对二叉树的一些常规操作 愿我们都能保持一颗向上的心 目录 一 前序遍历 二 中序遍历 三 后序遍历 四 统计节点个数 五 统计叶子节点个数 六 第K层的节点个数 七 二叉树的深度 八 查找值为x的节
  • bash 刷题leetcode

    题目一 给定一个文本文件 file txt 请只打印这个文件中的第十行 示例 假设 file txt 有如下内容 Line 1 Line 2 Line 3 Line 4 Line 5 Line 6 Line 7 Line 8 Line 9
  • Revit更改用户选择

    private void ChangeSelection Document document UIDocument uidoc new UIDocument document Autodesk Revit UI Selection SelE
  • 2014 奇虎360 笔试主观题

    1 在审计某一开源项目的代码时 假设有下面一个foo 子函数的实现 从安全的角度看 会存在安全漏洞吗 有的话 请 1 描述漏洞细节 2 说明可以利用的方法 3 还有该怎么修补漏洞 没有的话 也请说明为什么 int foo void func
  • QT.setStyleSheet()用法

    1 基本用法 textViewer gt setStyleSheet background color 00FF00 背景颜色 color FF0000 前景色 color rgb 255 0 0 color rgbd 255 0 0 0
  • selenium+java实现web自动化例子

    简单记录 有不正确的地方请指出 selenium java可以实现对web页面的自动化控制 在公司内部比较稳定 页面迭代较少的后台web系统使用时非常有效 web自动化收益最大化的情况 1 多更新于后端 前端页面迭代较少 2 在日常迭代中页
  • C++vector容器

    vector容器被称为动态数组 也被称为向量 它与array容器的区别是 array是静态数组 动态扩展 并不是在原空间之后续接新空间 而是找更大的内存空间 然后将原数据拷贝新空间 释放原空间 at 函数 返回对矢量中指定位置的元素的引用
  • 第零章 内核网络相关配置选项--基于Linux 3.10

    Kconfig选项 packet protocol 被直接和网络设备通信的应用程序使用 其没有使用内核的其它协议 像tcpdump支持需要使能该选项 af packet lt gt Packet socket 支持PF PACKET套接字
  • kubeadm构建(Calico+Dashboard+Containerd)

    文章目录 前言 一 环境 二 部署容器网络 CNI master操作 1 下载yamll 2 修改yaml 3 部署 三 部署 Dashboard 1 下载yaml 2 修改yaml 3 部署 4 创建管理员 四 切换容器引擎为Contai
  • 区块链开发之Solidity编程基础(一)

    Solidy是当前编写智能合约的主流语言 概要 sol文件结构 编译开发 引入其他文件 注释 代码注释 文档注释 合约 状态变量 类型 值类型 1 布尔类型 2 整型 3 地址 4 定长字节数组 5 有理数和整型字面量 6 枚举类型 7 函
  • controller与servlet的区别

    理解1 你可以理解为 Spring MVC是基于servlet的 它有一个DispatherServlet 然后它负责处理请求 并且调用了你的controller 打一个比方 web网站是应用程序么 你可以说浏览器是一个应用程序 而web网
  • ElementUi tab组件切换导致echarts宽度变窄问题

    解决tab组件变成100px的问题 使用echarts实例自带的resize 方法
  • 大话数据结构:栈与队列(1)

    栈 限定仅在表尾进行插入和删除操作的线性表 栈顶 允许插入和删除的一端 栈底 不允许插入和删除的一端 空栈 不含任何数据元素的栈 后进先出的线性表 LIFO结构 进栈 栈的插入 出栈 栈的删除 元素数量多 出栈的变化会更多 栈的抽象数据类型
  • pytorch实战-图像分类(一)(数据预处理)

    目录 1 导入各种库 2 数据预处理 2 1数据读取 2 2图像增强 3 构建数据网络 3 1网络构建 3 2读取标签对应的名字 4 展示数据 4 1数据转换 4 2画图 5 模型训练 1 导入各种库 上代码 import os impor