PyTorch-10 自定义数据集实战(Load data自定义数据集、Build model创建一个模型、Train and Test、Transfer Learning迁移学习)

2023-10-30

PyTorch-10 自定义数据集实战(Load data自定义数据集、Build model创建一个模型、Train and Test、Transfer Learning迁移学习)

我们以Pokemon Dataset作为自定义数据集:数据集下载地址https://download.csdn.net/download/m0_37755995/85686744
主要以下面5类小精灵。
在这里插入图片描述
查看一下分别有多少张图片,以及splitting划分train和test的比例:
我们并不是每一类的60%做training,而是整体234+239+223+238+234的60%做training。
60%是做training的,剩余的20%做validation,20%做test。
还需要注意的是,如果test的样本量比较少的话,去做test的时候,性能波动是比较大的,因此我们可以有目的性的增大validation和test的样本量。
在这里插入图片描述
这个实战案例主要分为四大步骤:
1、Load data 数据集加载(这里我们是自定义的数据集)
2、Build model 创建一个模型,这个模型会基于之前的ResNet进行少量的修改。
3、Train and Test 进行完整的训练training和测试validation
4、Transfer Learning 由于数据集比较小,容易出现overfitting的情况,这一步骤就是在解决这个情况(迁移学习)。

Step1.Load data数据集的加载工作

如何完成自定义数据集的加载工作呢?
文件名:pokemon.py
分为3小步:
第1步、Inherit from torch.utils.data.Dataset 继承一个通用的母类,和之前的所有网络层都继承于nn.Module这个类一样,这里的数据集也是继承于torch.utils.data.Dataset这个母类。

此外我们还需要自行实现下面两个函数:
第2步、len 这个函数代表数据集总体样本的数量,返回一个整型数字。
第3步、getitem 实现这个接口,通过这个接口我们就可以返回一个x样本。

我们先查看一个最简单的自定义数据的示例:(读取数字的数据集)
下面的就是一个三步完成读取数字数据集的过程,在现实中可能比这个要复杂些,但是步骤就这三步:

import  numpy as np
import  torch
import  torch.nn as nn
import  torch.optim as optim
from    matplotlib import pyplot as plt
from torch import nn, optim
from torch.utils.data import Dataset

class NumberDataset(Dataset): #第一步:首先继承自Dataset这个基本母类
    #初始化函数中:
    #会先将数据存储起来,比如说创建的1-1000的list,并且将其保存到类的一个成员变量上面。
    def __init__(self,training = True): #这里有一个参数training(布尔值),用来表示这个数据集是否是从training数据集中读取还是test数据集读取。
        #根据参数training是True还是False来判断是在哪个范围内取值。
        #train就从train的数据集读取,test就读取test的数据。
        #如果参数training为True,就从1-1000来sample数据。
        if training:
            self.samples = list(range(1,1001))
        #如果参数training为False(即test),就从1001-1500做这个sample。
        else:
            self.samples = list(range(1001,1501))

    #第二步:完成__len__这个函数:
    #这里是表示数据集x总体的数量,这个值就是来源于成员变量sample这个载体的长度。
    def __len__(self):
        return len(self.samples) #将成员变量的长度返回,告知这个数据集总共有多少个。
    #这个__len__函数有什么用呢:如果这个sample总体数量是100,因此只能迭代100次,迭代100次之后,这个数据集迭代一遍的工作就完成了。
    #同时,每一次得到的样本给以给一个index,表示当前取的是哪一个位置的样本,这个index不是自己决定的,这里的dataset只能做一个迭代的工作,不能去调整索引,因此这个index自动从0-999产生。
    #这个index的最大值是来源于这个数据集总体的大小。如果数据有1000的话,那么这个index就是从0-999。

    #第三步:完成__getitem__这个接口
    #index获得了之后,就需要实现返回当前index下的样本值x。
    #这里所有样本都是采样于self.samples,样本的idx,就返回当前idx下的样本值。
    #这里idx就是样本的长度len。
    def __getitem__(self, idx): #参数index
        return self.samples[idx]

数据的预处理工作

1、Image Resize
我们所采集的图片大小可能是任意大小的,但是我们深度学习所输入的图片大小往往是固定的,是标准的正方形图片。我们神经网络接收的是固定的size的图片,因此在将图片送入new network之前,需要确保图片的size是满足new network要求的;如果不满足,需要resize。

比如说,我们采用的模型是ResNet18,其接收图片的大小就是224224,我们通过resize将图片调整到224224这个大小。

2、Data Argumentation
Resize完成后,进行数据增强工作。这一步就是用来增加数据集的规模。

比如说随机的选择,裁切等

3、Normalize
这里的Normalize是把原来0-255之间分布的图片像素值,将其scale到0附进,比如说-1-1或则-0.5-0.5。通过将值scale到0的周围,使得training时更加稳定,更容易收敛。

4、ToTensor
将numpy或image的数据类型转为Tensor类型。

开始实战:

1、简单查看一下数据集
发现图片大小都是不规则的,而且有的图片是png,有的图片是jpg。
在这里插入图片描述
一个根文件夹下面包含5个子文件夹,每个字文件夹的命名就是分类的label的命名,使用这种结构(一个跟文件夹,和若干label命名的子文件夹)可以方便的通过pytorch用一行代码进行读取,不需要我们再去写一个自定义数据集读取的逻辑:
在这里插入图片描述

第一部分:我们先做一个自定义数据集的读取逻辑,之后再用另一个方法load进来,并实现一个样本的可视化:

一、首先完成第一部分:自定义数据集类的初始化函数init

这里我们创建一个用于字典,给每一个分类做一个映射,即当前的皮卡丘、妙蛙种子等这个string类型所对应的label是多少,这个是需要我们人为进行编码的。

import torch
import os, glob
import random, csv

# 所有自定义数据集的一个母类
from torch.utils.data import Dataset, DataLoader
#常用的图片变换器
from torchvision import transforms
#从图片读取出数据
from PIL import Image

# 自定义数据集的类,继承自Dataset
class Pokemon(Dataset):
    # 初始化函数init
    # 第一个参数root:总的图片所在的位置,可以是任意的位置,我们的图片可以放在任意的位置,我们这里就存储在当前目录pokemon文件夹下。
    # 第二个参数resize:图片输出的size,是由这个参数所进行设定。
    # 第三个参数mode:这里我们需要做train、validation以及test,对应这三种数据结构,因此我们用一个list[0,1,2]来代表是哪个模式。
    def __init__(self, root, resize, mode):
        # 先调用母类的初始化函数:
        super(Pokemon, self).__init__()
        #1、首先我们将这个参数保存下来
        self.root = root
        self.resize = resize

        #2、给每一个分类做一个映射,即当前的皮卡丘、妙蛙种子等这个string类型所对应的label是多少,这个是需要我们人为进行编码的。
        self.name2label = {} #用字典来表示映射关系
        #通过循环方式,将root路径下的文件夹名进行编码
        for name in sorted(os.listdir(os.path.join(root))):
            #过滤掉非文件夹:如果不是dir,就过滤掉,此外我们还通过sorted排序的方法,将键值对关系固定下来
            if not os.path.isdir(os.path.join(root,name)):
                continue
            #文件名做key,当前name2label的长度做value
            self.name2label[name] = len(self.name2label.keys())

        print(self.name2label)


    # 完成两个自定义的逻辑
    # 1、样本的总体数量(图片总体数量),返回的是一个数字,总体图片大概有1168张,60%用于training,因此返回6-7百张图片
    def __len__(self):
        pass

    # 2、用于返回当前index上面元素的值,这里是返回两个数据:
    # 需要返回当前image的data,以及image所对应的label[0,1,2,3,4]
    def __getitem__(self, idx):
        pass


#创建一个调试函数:
def main():
    db = Pokemon('pokemon',224,'train')


if __name__ == '__main__':
    main()

这块的核心就是保持编码的映射关系:
在这里插入图片描述
在这里插入图片描述

二、创建一个csv,用于写入图片全路径和对应的标签label

import torch
import os, glob
import random, csv

# 所有自定义数据集的一个母类
from torch.utils.data import Dataset, DataLoader
#常用的图片变换器
from torchvision import transforms
#从图片读取出数据
from PIL import Image

# 自定义数据集的类,继承自Dataset
class Pokemon(Dataset):
    # 一、初始化函数init
    # 第一个参数root:总的图片所在的位置,可以是任意的位置,我们的图片可以放在任意的位置,我们这里就存储在当前目录pokemon文件夹下。
    # 第二个参数resize:图片输出的size,是由这个参数所进行设定。
    # 第三个参数mode:这里我们需要做train、validation以及test,对应这三种数据结构,因此我们用一个list[0,1,2]来代表是哪个模式。
    def __init__(self, root, resize, mode):
        # 先调用母类的初始化函数:
        super(Pokemon, self).__init__()
        #1、首先我们将这个参数保存下来
        self.root = root
        self.resize = resize

        #2、给每一个分类做一个映射,即当前的皮卡丘、妙蛙种子等这个string类型所对应的label是多少,这个是需要我们人为进行编码的。
        self.name2label = {} #用字典来表示映射关系
        
        #通过循环方式,将root路径下的文件夹名进行编码
        for name in sorted(os.listdir(os.path.join(root))):
            #过滤掉非文件夹:如果不是dir,就过滤掉,此外我们还通过sorted排序的方法,将键值对关系固定下来
            if not os.path.isdir(os.path.join(root,name)):
                continue
            #文件名做key,当前name2label的长度做value
            self.name2label[name] = len(self.name2label.keys())

        print(self.name2label)

        #image, label
        self.load_csv('images.csv')
    
    #二、创建一个csv,用于保存图片全路径和对应的标签label
    #这个函数接受一个参数filename
    #这个函数中需要将所有图片都load进来
    def load_csv(self, filename):
        
        images = []
        for name in self.name2label.keys():
            #类别信息我们可以使用路径来判断
            #比如:'pokemon\\mewtwo\\00001.png'
            #上面路径的mewtwo就是类别
            images += glob.glob(os.path.join(self.root, name, '*.png'))
            images += glob.glob(os.path.join(self.root, name, '*.jpg'))
            images += glob.glob(os.path.join(self.root, name, '*.jpeg'))

        #1167张, 'pokemon\\bulbasaur\\00000000.png'
        print(len(images), images)
        
        #将images顺序打乱
        random.shuffle(images)
        
        #打开这个文件
        with open(os.path.join(self.root,filename),mode='w',newline='') as f:
            #新建writer,获得csv这个文件对象
            writer = csv.writer(f)
            for img in images: #获得每行信息'pokemon\\bulbasaur\\00000000.png'
                #通过分割符,将每行信息的内容分割开,取导数第二个,类型
                name = img.split(os.sep)[-2]
                
                #通过获取的类型名来获取label
                label = self.name2label[name]
                
                #将这个label信息写到csv中
                #csv是以逗号作为分割的
                #形式为:'pokemon\\bulbasaur\\00000000.png', 0
                writer.writerow([img,label])
            print('writen into csv file:',filename)

    # 三、完成两个自定义的逻辑
    # 1、样本的总体数量(图片总体数量),返回的是一个数字,总体图片大概有1168张,60%用于training,因此返回6-7百张图片
    def __len__(self):
        pass

    # 2、用于返回当前index上面元素的值,这里是返回两个数据:
    # 需要返回当前image的data,以及image所对应的label[0,1,2,3,4]
    def __getitem__(self, idx):
        pass


#创建一个调试函数:
def main():
    db = Pokemon('pokemon',224,'train')


if __name__ == '__main__':
    main()

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

三、这里需要注意的是,在上面写入csv的基础上,编写一个读取这个csv的方法。

如果这个csv文件以及存在,我们只需读取出来即可,因此增加一个判断,如果文件不存在,就需要创建这个csv,创建好后直接读取即可:

import torch
import os, glob
import random, csv

# 所有自定义数据集的一个母类
from torch.utils.data import Dataset, DataLoader
#常用的图片变换器
from torchvision import transforms
#从图片读取出数据
from PIL import Image

# 自定义数据集的类,继承自Dataset
class Pokemon(Dataset):
    # 一、初始化函数init
    # 第一个参数root:总的图片所在的位置,可以是任意的位置,我们的图片可以放在任意的位置,我们这里就存储在当前目录pokemon文件夹下。
    # 第二个参数resize:图片输出的size,是由这个参数所进行设定。
    # 第三个参数mode:这里我们需要做train、validation以及test,对应这三种数据结构,因此我们用一个list[0,1,2]来代表是哪个模式。
    def __init__(self, root, resize, mode):
        # 先调用母类的初始化函数:
        super(Pokemon, self).__init__()
        #1、首先我们将这个参数保存下来
        self.root = root
        self.resize = resize

        #2、给每一个分类做一个映射,即当前的皮卡丘、妙蛙种子等这个string类型所对应的label是多少,这个是需要我们人为进行编码的。
        self.name2label = {} #用字典来表示映射关系

        #通过循环方式,将root路径下的文件夹名进行编码
        for name in sorted(os.listdir(os.path.join(root))):
            #过滤掉非文件夹:如果不是dir,就过滤掉,此外我们还通过sorted排序的方法,将键值对关系固定下来
            if not os.path.isdir(os.path.join(root,name)):
                continue
            #文件名做key,当前name2label的长度做value
            self.name2label[name] = len(self.name2label.keys())

        print(self.name2label)

        #image, label
        self.images, self.labels = self.load_csv('images.csv')

    #二、创建一个csv,用于保存图片全路径和对应的标签label
    #这个函数接受一个参数filename
    #这个函数中需要将所有图片都load进来
    def load_csv(self, filename):
        # 需要一个判断,如果文件不存在,就需要创建csv,直接读取创建好的csv文件内容即可:
        #如果不存在,就需要创建csv
        if not os.path.exists(os.path.join(self.root, filename)):

            images = []
            for name in self.name2label.keys():
                #类别信息我们可以使用路径来判断
                #比如:'pokemon\\mewtwo\\00001.png'
                #上面路径的mewtwo就是类别
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))

            #1167张, 'pokemon\\bulbasaur\\00000000.png'
            print(len(images), images)

            #将images顺序打乱
            random.shuffle(images)

            #打开这个文件
            with open(os.path.join(self.root,filename),mode='w',newline='') as f:
                #新建writer,写入csv这个文件对象
                writer = csv.writer(f)
                for img in images: #获得每行信息'pokemon\\bulbasaur\\00000000.png'
                    #通过分割符,将每行信息的内容分割开,取导数第二个,类型
                    name = img.split(os.sep)[-2]

                    #通过获取的类型名来获取label
                    label = self.name2label[name]

                    #将这个label信息写到csv中
                    #csv是以逗号作为分割的
                    #形式为:'pokemon\\bulbasaur\\00000000.png', 0
                    writer.writerow([img,label])
                print('writen into csv file:',filename)

        #三、读取csv文件过程:
        #这里需要在开头有一个判断,如果csv存在,就不用写入csv了,直接进行读取
        #下次运行的时候只需加载进来即可
        images,labels = [],[]
        with open(os.path.join(self.root, filename)) as f:
            # 新建reader,读取csv这个文件对象
            reader = csv.reader(f)
            for row in reader:
                img, label =row #解包出来:'pokemon\\bulbasaur\\00000000.png', 0
                label = int(label) #将这个label转码为int类型
                #将img每个图片路径,以及label保存在建立好的列表对象中。
                images.append(img)
                labels.append(label)
        print(images[:4])
        print(labels[:4])
        print('len(images):',len(images))
        print('len(labels):',len(labels))
        assert len(images) == len(labels)
        print('read csv file:', filename)
        return images, labels

    # 三、完成两个自定义的逻辑
    # 1、样本的总体数量(图片总体数量),返回的是一个数字,总体图片大概有1168张,60%用于training,因此返回6-7百张图片
    def __len__(self):
        pass

    # 2、用于返回当前index上面元素的值,这里是返回两个数据:
    # 需要返回当前image的data,以及image所对应的label[0,1,2,3,4]
    def __getitem__(self, idx):
        pass


#创建一个调试函数:
def main():
    db = Pokemon('pokemon',224,'train')


if __name__ == '__main__':
    main()

在这里插入图片描述

四、不同比例模式下对图片数量进行划分

training模式下,取所有图片的60%作为样本。
validation模式下,取剩余图片的20%作为样本。
test模式下,取剩下的20%作为样本。

import torch
import os, glob
import random, csv

# 所有自定义数据集的一个母类
from torch.utils.data import Dataset, DataLoader
#常用的图片变换器
from torchvision import transforms
#从图片读取出数据
from PIL import Image

# 自定义数据集的类,继承自Dataset
class Pokemon(Dataset):
    # 一、初始化函数init
    # 第一个参数root:总的图片所在的位置,可以是任意的位置,我们的图片可以放在任意的位置,我们这里就存储在当前目录pokemon文件夹下。
    # 第二个参数resize:图片输出的size,是由这个参数所进行设定。
    # 第三个参数mode:这里我们需要做train、validation以及test,对应这三种数据结构,因此我们用一个list[0,1,2]来代表是哪个模式。
    def __init__(self, root, resize, mode):
        # 先调用母类的初始化函数:
        super(Pokemon, self).__init__()
        #1、首先我们将这个参数保存下来
        self.root = root
        self.resize = resize

        #2、给每一个分类做一个映射,即当前的皮卡丘、妙蛙种子等这个string类型所对应的label是多少,这个是需要我们人为进行编码的。
        self.name2label = {} #用字典来表示映射关系

        #通过循环方式,将root路径下的文件夹名进行编码
        for name in sorted(os.listdir(os.path.join(root))):
            #过滤掉非文件夹:如果不是dir,就过滤掉,此外我们还通过sorted排序的方法,将键值对关系固定下来
            if not os.path.isdir(os.path.join(root,name)):
                continue
            #文件名做key,当前name2label的长度做value
            self.name2label[name] = len(self.name2label.keys())

        print(self.name2label)

        #将self.load_csv的返回值images, labels赋予self.images, self.labels
        self.images, self.labels = self.load_csv('images.csv')

        # 四、不同比例模式下对图片数量进行划分
        if mode == 'train': #取60%做training
            #len(self.images)的长度是1167,取60%做为train模式的图片
            self.images = self.images[:int(0.6*len(self.images))]
            self.labels = self.labels[:int(0.6*len(self.labels))]
        elif mode =='val': #取20%做validation, 60%-80%
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else: #mode为test,取80%到最末尾
            self.images = self.images[int(0.8*len(self.images)):]
            self.labels = self.labels[int(0.8*len(self.labels)):]

    #二、创建一个csv,用于保存图片全路径和对应的标签label
    #这个函数接受一个参数filename
    #这个函数中需要将所有图片都load进来
    def load_csv(self, filename):
        # 需要一个判断,如果文件不存在,就需要创建csv,直接读取创建好的csv文件内容即可:
        #如果不存在,就需要创建csv
        if not os.path.exists(os.path.join(self.root, filename)):

            images = []
            for name in self.name2label.keys():
                #类别信息我们可以使用路径来判断
                #比如:'pokemon\\mewtwo\\00001.png'
                #上面路径的mewtwo就是类别
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))

            #1167张, 'pokemon\\bulbasaur\\00000000.png'
            print(len(images), images)

            #将images顺序打乱
            random.shuffle(images)

            #打开这个文件
            with open(os.path.join(self.root,filename),mode='w',newline='') as f:
                #新建writer,写入csv这个文件对象
                writer = csv.writer(f)
                for img in images: #获得每行信息'pokemon\\bulbasaur\\00000000.png'
                    #通过分割符,将每行信息的内容分割开,取导数第二个,类型
                    name = img.split(os.sep)[-2]

                    #通过获取的类型名来获取label
                    label = self.name2label[name]

                    #将这个label信息写到csv中
                    #csv是以逗号作为分割的
                    #形式为:'pokemon\\bulbasaur\\00000000.png', 0
                    writer.writerow([img,label])
                print('writen into csv file:',filename)

        #三、读取csv文件过程:
        #这里需要在开头有一个判断,如果csv存在,就不用写入csv了,直接进行读取
        #下次运行的时候只需加载进来即可
        images,labels = [],[]
        with open(os.path.join(self.root, filename)) as f:
            # 新建reader,读取csv这个文件对象
            reader = csv.reader(f)
            for row in reader:
                img, label =row #解包出来:'pokemon\\bulbasaur\\00000000.png', 0
                label = int(label) #将这个label转码为int类型
                #将img每个图片路径,以及label保存在建立好的列表对象中。
                images.append(img)
                labels.append(label)
        print(images[:4])
        print(labels[:4])
        print('len(images):',len(images))
        print('len(labels):',len(labels))
        assert len(images) == len(labels)
        print('read csv file:', filename)
        return images, labels

    # 三、完成两个自定义的逻辑
    # 1、样本的总体数量(图片总体数量),返回的是一个数字,总体图片大概有1168张,60%用于training,因此返回6-7百张图片
    def __len__(self):
        pass

    # 2、用于返回当前index上面元素的值,这里是返回两个数据:
    # 需要返回当前image的data,以及image所对应的label[0,1,2,3,4]
    def __getitem__(self, idx):
        pass


#创建一个调试函数:
def main():
    db = Pokemon('pokemon',224,'train')


if __name__ == '__main__':
    main()

五、完成总体样本数量函数的内容

import torch
import os, glob
import random, csv

# 所有自定义数据集的一个母类
from torch.utils.data import Dataset, DataLoader
#常用的图片变换器
from torchvision import transforms
#从图片读取出数据
from PIL import Image

# 自定义数据集的类,继承自Dataset
class Pokemon(Dataset):
    # 一、初始化函数init
    # 第一个参数root:总的图片所在的位置,可以是任意的位置,我们的图片可以放在任意的位置,我们这里就存储在当前目录pokemon文件夹下。
    # 第二个参数resize:图片输出的size,是由这个参数所进行设定。
    # 第三个参数mode:这里我们需要做train、validation以及test,对应这三种数据结构,因此我们用一个list[0,1,2]来代表是哪个模式。
    def __init__(self, root, resize, mode):
        # 先调用母类的初始化函数:
        super(Pokemon, self).__init__()
        #1、首先我们将这个参数保存下来
        self.root = root
        self.resize = resize

        #2、给每一个分类做一个映射,即当前的皮卡丘、妙蛙种子等这个string类型所对应的label是多少,这个是需要我们人为进行编码的。
        self.name2label = {} #用字典来表示映射关系

        #通过循环方式,将root路径下的文件夹名进行编码
        for name in sorted(os.listdir(os.path.join(root))):
            #过滤掉非文件夹:如果不是dir,就过滤掉,此外我们还通过sorted排序的方法,将键值对关系固定下来
            if not os.path.isdir(os.path.join(root,name)):
                continue
            #文件名做key,当前name2label的长度做value
            self.name2label[name] = len(self.name2label.keys())

        print(self.name2label)

        #将self.load_csv的返回值images, labels赋予self.images, self.labels
        self.images, self.labels = self.load_csv('images.csv')

        # 四、不同比例模式下对图片数量进行划分
        if mode == 'train': #取60%做training
            #len(self.images)的长度是1167,取60%做为train模式的图片
            self.images = self.images[:int(0.6*len(self.images))]
            self.labels = self.labels[:int(0.6*len(self.labels))]
        elif mode =='val': #取20%做validation, 60%-80%
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else: #mode为test,取80%到最末尾
            self.images = self.images[int(0.8*len(self.images)):]
            self.labels = self.labels[int(0.8*len(self.labels)):]

    #二、创建一个csv,用于保存图片全路径和对应的标签label
    #这个函数接受一个参数filename
    #这个函数中需要将所有图片都load进来
    def load_csv(self, filename):
        # 需要一个判断,如果文件不存在,就需要创建csv,直接读取创建好的csv文件内容即可:
        #如果不存在,就需要创建csv
        if not os.path.exists(os.path.join(self.root, filename)):

            images = []
            for name in self.name2label.keys():
                #类别信息我们可以使用路径来判断
                #比如:'pokemon\\mewtwo\\00001.png'
                #上面路径的mewtwo就是类别
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))

            #1167张, 'pokemon\\bulbasaur\\00000000.png'
            print(len(images), images)

            #将images顺序打乱
            random.shuffle(images)

            #打开这个文件
            with open(os.path.join(self.root,filename),mode='w',newline='') as f:
                #新建writer,写入csv这个文件对象
                writer = csv.writer(f)
                for img in images: #获得每行信息'pokemon\\bulbasaur\\00000000.png'
                    #通过分割符,将每行信息的内容分割开,取导数第二个,类型
                    name = img.split(os.sep)[-2]

                    #通过获取的类型名来获取label
                    label = self.name2label[name]

                    #将这个label信息写到csv中
                    #csv是以逗号作为分割的
                    #形式为:'pokemon\\bulbasaur\\00000000.png', 0
                    writer.writerow([img,label])
                print('writen into csv file:',filename)

        #三、读取csv文件过程:
        #这里需要在开头有一个判断,如果csv存在,就不用写入csv了,直接进行读取
        #下次运行的时候只需加载进来即可
        images,labels = [],[]
        with open(os.path.join(self.root, filename)) as f:
            # 新建reader,读取csv这个文件对象
            reader = csv.reader(f)
            for row in reader:
                img, label =row #解包出来:'pokemon\\bulbasaur\\00000000.png', 0
                label = int(label) #将这个label转码为int类型
                #将img每个图片路径,以及label保存在建立好的列表对象中。
                images.append(img)
                labels.append(label)
        print(images[:4])
        print(labels[:4])
        print('len(images):',len(images))
        print('len(labels):',len(labels))
        assert len(images) == len(labels)
        print('read csv file:', filename)
        return images, labels

    # 完成两个自定义的逻辑:
    # 1、样本的总体数量(图片总体数量),返回的是一个数字,总体图片大概有1168张,60%用于training,因此返回6-7百张图片
    # 五、完成总体样本数量
    def __len__(self):
        #这里的样本长度是跟模型类别来决定的,上面已经根据不同模型类型划分了样本数量了。
        #不同模式下,样本长度是不同的。
        #因此这里的总体样本长度,就是不同模式下的样本数量。
        return len(self.images)

    # 2、用于返回当前index上面元素的值,这里是返回两个数据:
    # 需要返回当前image的data,以及image所对应的label[0,1,2,3,4]
    def __getitem__(self, idx):
        pass


#创建一个调试函数:
def main():
    db = Pokemon('pokemon',224,'train')


if __name__ == '__main__':
    main()

六、完成image、label 与 index 索引的一 一对应。

import torch
import os, glob
import random, csv

# 所有自定义数据集的一个母类
from torch.utils.data import Dataset, DataLoader

# 常用的图片变换器
from torchvision import transforms
# 从图片读取出数据
from PIL import Image


# 自定义数据集的类,继承自Dataset
class Pokemon(Dataset):
    # 一、初始化函数init
    # 第一个参数root:总的图片所在的位置,可以是任意的位置,我们的图片可以放在任意的位置,我们这里就存储在当前目录pokemon文件夹下。
    # 第二个参数resize:图片输出的size,是由这个参数所进行设定。
    # 第三个参数mode:这里我们需要做train、validation以及test,对应这三种数据结构,因此我们用一个list[0,1,2]来代表是哪个模式。
    def __init__(self, root, resize, mode):
        # 先调用母类的初始化函数:
        super(Pokemon, self).__init__()
        # 1、首先我们将这个参数保存下来
        self.root = root
        self.resize = resize

        # 2、给每一个分类做一个映射,即当前的皮卡丘、妙蛙种子等这个string类型所对应的label是多少,这个是需要我们人为进行编码的。
        self.name2label = {}  # 用字典来表示映射关系

        # 通过循环方式,将root路径下的文件夹名进行编码
        for name in sorted(os.listdir(os.path.join(root))):
            # 过滤掉非文件夹:如果不是dir,就过滤掉,此外我们还通过sorted排序的方法,将键值对关系固定下来
            if not os.path.isdir(os.path.join(root, name)):
                continue
            # 文件名做key,当前name2label的长度做value
            self.name2label[name] = len(self.name2label.keys())

        print(self.name2label)

        # 将self.load_csv的返回值images, labels赋予self.images, self.labels
        self.images, self.labels = self.load_csv('images.csv')

        # 四、不同比例模式下对图片数量进行划分
        if mode == 'train':  # 取60%做training
            # len(self.images)的长度是1167,取60%做为train模式的图片
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif mode == 'val':  # 取20%做validation, 60%-80%
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:  # mode为test,取80%到最末尾
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]

    # 二、创建一个csv,用于保存图片全路径和对应的标签label
    # 这个函数接受一个参数filename
    # 这个函数中需要将所有图片都load进来
    def load_csv(self, filename):
        # 需要一个判断,如果文件不存在,就需要创建csv,直接读取创建好的csv文件内容即可:
        # 如果不存在,就需要创建csv
        if not os.path.exists(os.path.join(self.root, filename)):

            images = []
            for name in self.name2label.keys():
                # 类别信息我们可以使用路径来判断
                # 比如:'pokemon\\mewtwo\\00001.png'
                # 上面路径的mewtwo就是类别
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))

            # 1167张, 'pokemon\\bulbasaur\\00000000.png'
            print(len(images), images)

            # 将images顺序打乱
            random.shuffle(images)

            # 打开这个文件
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                # 新建writer,写入csv这个文件对象
                writer = csv.writer(f)
                for img in images:  # 获得每行信息'pokemon\\bulbasaur\\00000000.png'
                    # 通过分割符,将每行信息的内容分割开,取导数第二个,类型
                    name = img.split(os.sep)[-2]

                    # 通过获取的类型名来获取label
                    label = self.name2label[name]

                    # 将这个label信息写到csv中
                    # csv是以逗号作为分割的
                    # 形式为:'pokemon\\bulbasaur\\00000000.png', 0
                    writer.writerow([img, label])
                print('writen into csv file:', filename)

        # 三、读取csv文件过程:
        # 这里需要在开头有一个判断,如果csv存在,就不用写入csv了,直接进行读取
        # 下次运行的时候只需加载进来即可
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            # 新建reader,读取csv这个文件对象
            reader = csv.reader(f)
            for row in reader:
                img, label = row  # 解包出来:'pokemon\\bulbasaur\\00000000.png', 0
                label = int(label)  # 将这个label转码为int类型
                # 将img每个图片路径,以及label保存在建立好的列表对象中。
                images.append(img)
                labels.append(label)
        print(images[:4])
        print(labels[:4])
        print('len(images):', len(images))
        print('len(labels):', len(labels))
        assert len(images) == len(labels)
        print('read csv file:', filename)
        return images, labels

    # 完成两个自定义的逻辑:
    # 1、样本的总体数量(图片总体数量),返回的是一个数字,总体图片大概有1168张,60%用于training,因此返回6-7百张图片
    # 五、完成总体样本数量函数的内容
    def __len__(self):
        # 这里的样本长度是跟模型类别来决定的,上面已经根据不同模型类型划分了样本数量了。
        # 不同模式下,样本长度是不同的。
        # 因此这里的总体样本长度,就是不同模式下的样本数量。
        return len(self.images)

    # 2、用于返回当前index上面元素的值,这里是返回两个数据:
    # 需要返回当前image的data,以及image所对应的label[0,1,2,3,4]
    # 六、完成index与样本的一一对应
    def __getitem__(self, idx):
        # idx数值范围是[0-len(images)]
        # self.images保存了所有的数据;self.labels保存了所有数据对应的label信息;
        # img是一个string类型(还不是具体的图片,只是路径):'pokemon\\bulbasaur\\00000000.png'
        # label是一个整数类型
        img, label = self.images[idx], self.labels[idx]

        # 这里就需要将img所对应的路径读取出图片,并转为tensor类型
        # 这里我们可以Compose组合操作步骤
        tf = transforms.Compose([
            # 这里需要将路径变成具体的图片数据类型
            # 即:string path => image data
            lambda x: Image.open(x).convert('RGB'),
            # Resize工作,这里的size是我们实例化时的self.resize的值
            transforms.Resize((self.resize, self.resize)),
            # 将数据变为tensor类型
            transforms.ToTensor()
        ])
        img = tf(img)
        label = torch.tensor(label)
        return img, label


# 创建一个调试函数:
def main():
    db = Pokemon('pokemon', 224, 'train')


if __name__ == '__main__':
    main()

以上的六个步骤基本上完成了自定义数据集的工作。

七、这里我们来完成一个验证工作,检验一下自定义数据集是否能够成功加载

我们需要先启动一下visdom,这里我们在PyCharm的终端输入:

python -m visdom.server  

就可以启动visdom了,我们获得了一个空白的visdom的监控界面。
在这里插入图片描述
在这里插入图片描述

import torch
import os, glob
import random, csv

# 所有自定义数据集的一个母类
from torch.utils.data import Dataset, DataLoader

# 常用的图片变换器
from torchvision import transforms
# 从图片读取出数据
from PIL import Image


# 自定义数据集的类,继承自Dataset
class Pokemon(Dataset):
    # 一、初始化函数init
    # 第一个参数root:总的图片所在的位置,可以是任意的位置,我们的图片可以放在任意的位置,我们这里就存储在当前目录pokemon文件夹下。
    # 第二个参数resize:图片输出的size,是由这个参数所进行设定。
    # 第三个参数mode:这里我们需要做train、validation以及test,对应这三种数据结构,因此我们用一个list[0,1,2]来代表是哪个模式。
    def __init__(self, root, resize, mode):
        # 先调用母类的初始化函数:
        super(Pokemon, self).__init__()
        # 1、首先我们将这个参数保存下来
        self.root = root
        self.resize = resize

        # 2、给每一个分类做一个映射,即当前的皮卡丘、妙蛙种子等这个string类型所对应的label是多少,这个是需要我们人为进行编码的。
        self.name2label = {}  # 用字典来表示映射关系

        # 通过循环方式,将root路径下的文件夹名进行编码
        for name in sorted(os.listdir(os.path.join(root))):
            # 过滤掉非文件夹:如果不是dir,就过滤掉,此外我们还通过sorted排序的方法,将键值对关系固定下来
            if not os.path.isdir(os.path.join(root, name)):
                continue
            # 文件名做key,当前name2label的长度做value
            self.name2label[name] = len(self.name2label.keys())

        print(self.name2label)

        # 将self.load_csv的返回值images, labels赋予self.images, self.labels
        self.images, self.labels = self.load_csv('images.csv')

        # 四、不同比例模式下对图片数量进行划分
        if mode == 'train':  # 取60%做training
            # len(self.images)的长度是1167,取60%做为train模式的图片
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif mode == 'val':  # 取20%做validation, 60%-80%
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:  # mode为test,取80%到最末尾
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]

    # 二、创建一个csv,用于保存图片全路径和对应的标签label
    # 这个函数接受一个参数filename
    # 这个函数中需要将所有图片都load进来
    def load_csv(self, filename):
        # 需要一个判断,如果文件不存在,就需要创建csv,直接读取创建好的csv文件内容即可:
        # 如果不存在,就需要创建csv
        if not os.path.exists(os.path.join(self.root, filename)):

            images = []
            for name in self.name2label.keys():
                # 类别信息我们可以使用路径来判断
                # 比如:'pokemon\\mewtwo\\00001.png'
                # 上面路径的mewtwo就是类别
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))

            # 1167张, 'pokemon\\bulbasaur\\00000000.png'
            print(len(images), images)

            # 将images顺序打乱
            random.shuffle(images)

            # 打开这个文件
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                # 新建writer,写入csv这个文件对象
                writer = csv.writer(f)
                for img in images:  # 获得每行信息'pokemon\\bulbasaur\\00000000.png'
                    # 通过分割符,将每行信息的内容分割开,取导数第二个,类型
                    name = img.split(os.sep)[-2]

                    # 通过获取的类型名来获取label
                    label = self.name2label[name]

                    # 将这个label信息写到csv中
                    # csv是以逗号作为分割的
                    # 形式为:'pokemon\\bulbasaur\\00000000.png', 0
                    writer.writerow([img, label])
                print('writen into csv file:', filename)

        # 三、读取csv文件过程:
        # 这里需要在开头有一个判断,如果csv存在,就不用写入csv了,直接进行读取
        # 下次运行的时候只需加载进来即可
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            # 新建reader,读取csv这个文件对象
            reader = csv.reader(f)
            for row in reader:
                img, label = row  # 解包出来:'pokemon\\bulbasaur\\00000000.png', 0
                label = int(label)  # 将这个label转码为int类型
                # 将img每个图片路径,以及label保存在建立好的列表对象中。
                images.append(img)
                labels.append(label)
        print(images[:4])
        print(labels[:4])
        print('len(images):', len(images))
        print('len(labels):', len(labels))
        assert len(images) == len(labels)
        print('read csv file:', filename)
        return images, labels

    # 完成两个自定义的逻辑:
    # 1、样本的总体数量(图片总体数量),返回的是一个数字,总体图片大概有1168张,60%用于training,因此返回6-7百张图片
    # 五、完成总体样本数量函数的内容
    def __len__(self):
        # 这里的样本长度是跟模型类别来决定的,上面已经根据不同模型类型划分了样本数量了。
        # 不同模式下,样本长度是不同的。
        # 因此这里的总体样本长度,就是不同模式下的样本数量。
        return len(self.images)

    # 2、用于返回当前index上面元素的值,这里是返回两个数据:
    # 需要返回当前image的data,以及image所对应的label[0,1,2,3,4]
    # 六、完成index与样本的一一对应
    def __getitem__(self, idx):
        # idx数值范围是[0-len(images)]
        # self.images保存了所有的数据;self.labels保存了所有数据对应的label信息;
        # img是一个string类型(还不是具体的图片,只是路径):'pokemon\\bulbasaur\\00000000.png'
        # label是一个整数类型
        img, label = self.images[idx], self.labels[idx]

        # 这里就需要将img所对应的路径读取出图片,并转为tensor类型
        # 这里我们可以Compose组合操作步骤
        tf = transforms.Compose([
            # 这里需要将路径变成具体的图片数据类型
            # 即:string path => image data
            lambda x: Image.open(x).convert('RGB'),
            # Resize工作,这里的size是我们实例化时的self.resize的值
            transforms.Resize((self.resize, self.resize)),
            # 将数据变为tensor类型
            transforms.ToTensor()
        ])
        img = tf(img)
        label = torch.tensor(label)
        return img, label


# 创建一个调试函数:
def main():
    # 七、验证自定义数据集
    #验证需要一些辅助函数,用visdom做一些可视化。
    import visdom
    import time

    #创建一个visdom这个对象
    viz = visdom.Visdom()

    db = Pokemon('pokemon', 224, 'train')

    # 首先是可视化一个样本
    # next() 返回迭代器的下一个项目。
    # next()函数要和生成迭代器的iter()函数一起使用。
    x,y = next(iter(db))
    print('sample:', x.shape,y.shape,y)

    #将图片可视化一下
    viz.image(x, win='sample_x', opts=dict(title='sample_x'))

if __name__ == '__main__':
    main()

这样我们的一个超梦样本就加载出来了。
在这里插入图片描述
在这里插入图片描述

八、数据预处理的工作

增加数据预处理的工作,在Compose中增加放大、旋转、裁切这三个数据增强的操作,data augmentation数据增强。

import torch
import os, glob
import random, csv

# 所有自定义数据集的一个母类
from torch.utils.data import Dataset, DataLoader

# 常用的图片变换器
from torchvision import transforms
# 从图片读取出数据
from PIL import Image


# 自定义数据集的类,继承自Dataset
class Pokemon(Dataset):
    # 一、初始化函数init
    # 第一个参数root:总的图片所在的位置,可以是任意的位置,我们的图片可以放在任意的位置,我们这里就存储在当前目录pokemon文件夹下。
    # 第二个参数resize:图片输出的size,是由这个参数所进行设定。
    # 第三个参数mode:这里我们需要做train、validation以及test,对应这三种数据结构,因此我们用一个list[0,1,2]来代表是哪个模式。
    def __init__(self, root, resize, mode):
        # 先调用母类的初始化函数:
        super(Pokemon, self).__init__()
        # 1、首先我们将这个参数保存下来
        self.root = root
        self.resize = resize

        # 2、给每一个分类做一个映射,即当前的皮卡丘、妙蛙种子等这个string类型所对应的label是多少,这个是需要我们人为进行编码的。
        self.name2label = {}  # 用字典来表示映射关系

        # 通过循环方式,将root路径下的文件夹名进行编码
        for name in sorted(os.listdir(os.path.join(root))):
            # 过滤掉非文件夹:如果不是dir,就过滤掉,此外我们还通过sorted排序的方法,将键值对关系固定下来
            if not os.path.isdir(os.path.join(root, name)):
                continue
            # 文件名做key,当前name2label的长度做value
            self.name2label[name] = len(self.name2label.keys())

        print(self.name2label)

        # 将self.load_csv的返回值images, labels赋予self.images, self.labels
        self.images, self.labels = self.load_csv('images.csv')

        # 四、不同比例模式下对图片数量进行划分
        if mode == 'train':  # 取60%做training
            # len(self.images)的长度是1167,取60%做为train模式的图片
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif mode == 'val':  # 取20%做validation, 60%-80%
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:  # mode为test,取80%到最末尾
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]

    # 二、创建一个csv,用于保存图片全路径和对应的标签label
    # 这个函数接受一个参数filename
    # 这个函数中需要将所有图片都load进来
    def load_csv(self, filename):
        # 需要一个判断,如果文件不存在,就需要创建csv,直接读取创建好的csv文件内容即可:
        # 如果不存在,就需要创建csv
        if not os.path.exists(os.path.join(self.root, filename)):

            images = []
            for name in self.name2label.keys():
                # 类别信息我们可以使用路径来判断
                # 比如:'pokemon\\mewtwo\\00001.png'
                # 上面路径的mewtwo就是类别
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))

            # 1167张, 'pokemon\\bulbasaur\\00000000.png'
            print(len(images), images)

            # 将images顺序打乱
            random.shuffle(images)

            # 打开这个文件
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                # 新建writer,写入csv这个文件对象
                writer = csv.writer(f)
                for img in images:  # 获得每行信息'pokemon\\bulbasaur\\00000000.png'
                    # 通过分割符,将每行信息的内容分割开,取导数第二个,类型
                    name = img.split(os.sep)[-2]

                    # 通过获取的类型名来获取label
                    label = self.name2label[name]

                    # 将这个label信息写到csv中
                    # csv是以逗号作为分割的
                    # 形式为:'pokemon\\bulbasaur\\00000000.png', 0
                    writer.writerow([img, label])
                print('writen into csv file:', filename)

        # 三、读取csv文件过程:
        # 这里需要在开头有一个判断,如果csv存在,就不用写入csv了,直接进行读取
        # 下次运行的时候只需加载进来即可
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            # 新建reader,读取csv这个文件对象
            reader = csv.reader(f)
            for row in reader:
                img, label = row  # 解包出来:'pokemon\\bulbasaur\\00000000.png', 0
                label = int(label)  # 将这个label转码为int类型
                # 将img每个图片路径,以及label保存在建立好的列表对象中。
                images.append(img)
                labels.append(label)
        print(images[:4])
        print(labels[:4])
        print('len(images):', len(images))
        print('len(labels):', len(labels))
        assert len(images) == len(labels)
        print('read csv file:', filename)
        return images, labels

    # 完成两个自定义的逻辑:
    # 1、样本的总体数量(图片总体数量),返回的是一个数字,总体图片大概有1168张,60%用于training,因此返回6-7百张图片
    # 五、完成总体样本数量函数的内容
    def __len__(self):
        # 这里的样本长度是跟模型类别来决定的,上面已经根据不同模型类型划分了样本数量了。
        # 不同模式下,样本长度是不同的。
        # 因此这里的总体样本长度,就是不同模式下的样本数量。
        return len(self.images)

    # 2、用于返回当前index上面元素的值,这里是返回两个数据:
    # 需要返回当前image的data,以及image所对应的label[0,1,2,3,4]
    # 六、完成index与样本的一一对应
    def __getitem__(self, idx):
        # idx数值范围是[0-len(images)]
        # self.images保存了所有的数据;self.labels保存了所有数据对应的label信息;
        # img是一个string类型(还不是具体的图片,只是路径):'pokemon\\bulbasaur\\00000000.png'
        # label是一个整数类型
        img, label = self.images[idx], self.labels[idx]

        # 这里就需要将img所对应的路径读取出图片,并转为tensor类型
        # 这里我们可以Compose组合操作步骤
        # 八、增加数据预处理的工作(数据增强),在Compose中增加这些内容,data augmentation数据增强
        # 这里我们做放大、旋转、裁切这三个数据增强的操作
        tf = transforms.Compose([
            # 这里需要将路径变成具体的图片数据类型
            # 即:string path => image data
            lambda x: Image.open(x).convert('RGB'),
            # Resize工作,这里的size是我们实例化时的self.resize的值
            # 1、data augmentation放大:在Resize设置的基础上,稍微调大一些size, 调整为1.25倍
            transforms.Resize((int(self.resize*1.25), int(self.resize*1.25))),
            # 2、data augmentation旋转:增加随机旋转,注意:这里旋转角度不能太大,会增加学习的难度。
            transforms.RandomRotation(15),
            # 3、data augmentation中心裁切:裁切为我们所需要的大小
            transforms.CenterCrop(self.resize),
            # 将数据变为tensor类型
            transforms.ToTensor()
            # 4、normalize处理,希望图片数值范围在0左右分布,而不希望数值只分布在0的右侧或只在左侧
            #其中参数统计的所有image net数据集几百万张图片的mean=[R的mean,G的mean,B的mean]和std=[R的方差,G的方差,B的方差]
            #基本上这个数值是通用的
            #数据通过Normalize处理后,就是在-1到1之间分布了。
            transforms.Normalize(mean = [0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        img = tf(img)
        label = torch.tensor(label)
        return img, label


# 创建一个调试函数:
def main():
    # 七、验证自定义数据集
    #验证需要一些辅助函数,用visdom做一些可视化。
    import visdom
    import time

    #创建一个visdom这个对象
    viz = visdom.Visdom()

    db = Pokemon('pokemon', 224, 'train')

    # 首先是可视化一个样本
    # next() 返回迭代器的下一个项目。
    # next()函数要和生成迭代器的iter()函数一起使用。
    x,y = next(iter(db))
    print('sample:', x.shape,y.shape,y)

    #将图片可视化一下
    viz.image(x, win='sample_x', opts=dict(title='sample_x'))

if __name__ == '__main__':
    main()

下图结果是未添加 transforms.Normalize(mean = [0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])的效果:发现图片效果正常
在这里插入图片描述
下图是添加了 transforms.Normalize(mean = [0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])的结果:
这是因为visdom接受的范围是0到1,通过normalize后,就是从-1到1了。
在这里插入图片描述

九、解决normalize处理后,visdom无法正常显示的问题

因为visdom接受的范围是0到1,通过normalize后,就是从-1到1了。因此需要denormalize

import torch
import os, glob
import random, csv

# 所有自定义数据集的一个母类
from torch.utils.data import Dataset, DataLoader

# 常用的图片变换器
from torchvision import transforms
# 从图片读取出数据
from PIL import Image


# 自定义数据集的类,继承自Dataset
class Pokemon(Dataset):
    # 一、初始化函数init
    # 第一个参数root:总的图片所在的位置,可以是任意的位置,我们的图片可以放在任意的位置,我们这里就存储在当前目录pokemon文件夹下。
    # 第二个参数resize:图片输出的size,是由这个参数所进行设定。
    # 第三个参数mode:这里我们需要做train、validation以及test,对应这三种数据结构,因此我们用一个list[0,1,2]来代表是哪个模式。
    def __init__(self, root, resize, mode):
        # 先调用母类的初始化函数:
        super(Pokemon, self).__init__()
        # 1、首先我们将这个参数保存下来
        self.root = root
        self.resize = resize

        # 2、给每一个分类做一个映射,即当前的皮卡丘、妙蛙种子等这个string类型所对应的label是多少,这个是需要我们人为进行编码的。
        self.name2label = {}  # 用字典来表示映射关系

        # 通过循环方式,将root路径下的文件夹名进行编码
        for name in sorted(os.listdir(os.path.join(root))):
            # 过滤掉非文件夹:如果不是dir,就过滤掉,此外我们还通过sorted排序的方法,将键值对关系固定下来
            if not os.path.isdir(os.path.join(root, name)):
                continue
            # 文件名做key,当前name2label的长度做value
            self.name2label[name] = len(self.name2label.keys())

        print(self.name2label)

        # 将self.load_csv的返回值images, labels赋予self.images, self.labels
        self.images, self.labels = self.load_csv('images.csv')

        # 四、不同比例模式下对图片数量进行划分
        if mode == 'train':  # 取60%做training
            # len(self.images)的长度是1167,取60%做为train模式的图片
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif mode == 'val':  # 取20%做validation, 60%-80%
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:  # mode为test,取80%到最末尾
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]

    # 二、创建一个csv,用于保存图片全路径和对应的标签label
    # 这个函数接受一个参数filename
    # 这个函数中需要将所有图片都load进来
    def load_csv(self, filename):
        # 需要一个判断,如果文件不存在,就需要创建csv,直接读取创建好的csv文件内容即可:
        # 如果不存在,就需要创建csv
        if not os.path.exists(os.path.join(self.root, filename)):

            images = []
            for name in self.name2label.keys():
                # 类别信息我们可以使用路径来判断
                # 比如:'pokemon\\mewtwo\\00001.png'
                # 上面路径的mewtwo就是类别
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))

            # 1167张, 'pokemon\\bulbasaur\\00000000.png'
            print(len(images), images)

            # 将images顺序打乱
            random.shuffle(images)

            # 打开这个文件
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                # 新建writer,写入csv这个文件对象
                writer = csv.writer(f)
                for img in images:  # 获得每行信息'pokemon\\bulbasaur\\00000000.png'
                    # 通过分割符,将每行信息的内容分割开,取导数第二个,类型
                    name = img.split(os.sep)[-2]

                    # 通过获取的类型名来获取label
                    label = self.name2label[name]

                    # 将这个label信息写到csv中
                    # csv是以逗号作为分割的
                    # 形式为:'pokemon\\bulbasaur\\00000000.png', 0
                    writer.writerow([img, label])
                print('writen into csv file:', filename)

        # 三、读取csv文件过程:
        # 这里需要在开头有一个判断,如果csv存在,就不用写入csv了,直接进行读取
        # 下次运行的时候只需加载进来即可
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            # 新建reader,读取csv这个文件对象
            reader = csv.reader(f)
            for row in reader:
                img, label = row  # 解包出来:'pokemon\\bulbasaur\\00000000.png', 0
                label = int(label)  # 将这个label转码为int类型
                # 将img每个图片路径,以及label保存在建立好的列表对象中。
                images.append(img)
                labels.append(label)
        print(images[:4])
        print(labels[:4])
        print('len(images):', len(images))
        print('len(labels):', len(labels))
        assert len(images) == len(labels)
        print('read csv file:', filename)
        return images, labels

    # 完成两个自定义的逻辑:
    # 1、样本的总体数量(图片总体数量),返回的是一个数字,总体图片大概有1168张,60%用于training,因此返回6-7百张图片
    # 五、完成总体样本数量函数的内容
    def __len__(self):
        # 这里的样本长度是跟模型类别来决定的,上面已经根据不同模型类型划分了样本数量了。
        # 不同模式下,样本长度是不同的。
        # 因此这里的总体样本长度,就是不同模式下的样本数量。
        return len(self.images)

    # 九、解决normalize处理后,visdom无法正常显示的问题
    # 这里传入的参数x是normalize过后的
    def denormalize(self, x_hat):
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # normalize的过程:
        # x_hat = (x-mean)/std
        # x = x_hat*std+mean 这个就是逆操作过程
        # x = [c,h,w]
        # mean = [3] => boardcasting后
        # unsqueeze()函数起升维的作用,参数表示在哪个地方加一个维度
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        print('mean.shape,std.shape:',mean.shape,std.shape)
        x = x_hat * std + mean
        return x

    # 2、用于返回当前index上面元素的值,这里是返回两个数据:
    # 需要返回当前image的data,以及image所对应的label[0,1,2,3,4]
    # 六、完成index与样本的一一对应
    def __getitem__(self, idx):
        # idx数值范围是[0-len(images)]
        # self.images保存了所有的数据;self.labels保存了所有数据对应的label信息;
        # img是一个string类型(还不是具体的图片,只是路径):'pokemon\\bulbasaur\\00000000.png'
        # label是一个整数类型
        img, label = self.images[idx], self.labels[idx]

        # 这里就需要将img所对应的路径读取出图片,并转为tensor类型
        # 这里我们可以Compose组合操作步骤
        # 八、增加数据预处理的工作,在Compose中增加这些内容,data augmentation数据增强
        # 这里我们做放大、旋转、裁切这三个数据增强的操作
        tf = transforms.Compose([
            # 这里需要将路径变成具体的图片数据类型
            # 即:string path => image data
            lambda x: Image.open(x).convert('RGB'),
            # Resize工作,这里的size是我们实例化时的self.resize的值
            # 1、data augmentation放大:在Resize设置的基础上,稍微调大一些size, 调整为1.25倍
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            # 2、data augmentation旋转:增加随机旋转,注意:这里旋转角度不能太大,会增加学习的难度。
            transforms.RandomRotation(15),
            # 3、data augmentation中心裁切:裁切为我们所需要的大小
            transforms.CenterCrop(self.resize),
            # 将数据变为tensor类型
            transforms.ToTensor(),
            # 4、normalize处理,希望图片数值范围在0左右分布,而不希望数值只分布在0的右侧或只在左侧
            # 其中参数统计的所有image net数据集几百万张图片的mean=[R的mean,G的mean,B的mean]和std=[R的方差,G的方差,B的方差]
            # 基本上这个数值是通用的
            # 数据通过Normalize处理后,就是在-1到1之间分布了。
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        img = tf(img)
        label = torch.tensor(label)
        return img, label


# 创建一个调试函数:
def main():
    # 七、验证自定义数据集
    # 验证需要一些辅助函数,用visdom做一些可视化。
    import visdom
    import time

    # 创建一个visdom这个对象
    viz = visdom.Visdom()

    db = Pokemon('pokemon', 224, 'train')

    # 首先是可视化一个样本
    # next() 返回迭代器的下一个项目。
    # next()函数要和生成迭代器的iter()函数一起使用。
    x, y = next(iter(db))
    print('sample:', x.shape, y.shape, y)

    # 将图片可视化一下
    #这里通过第九步创建的denormalize函数,将x的normalize逆处理一下。
    viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))


if __name__ == '__main__':
    main()

可以看到unsqueeze()函数起升维的作用,mean.shape,std.shape升维了: torch.Size([3, 1, 1]) torch.Size([3, 1, 1])
在这里插入图片描述
下图就和没有normalize处理显示的内容一样了。
在这里插入图片描述

十、实现加载batch个图片

这里我们只完成了加载一张图片的功能,实际上我们做training的时候往往我们希望加载batch个图片。

pokemon.py

import torch
import os, glob
import random, csv

# 所有自定义数据集的一个母类
from torch.utils.data import Dataset, DataLoader

# 常用的图片变换器
from torchvision import transforms
# 从图片读取出数据
from PIL import Image


# 自定义数据集的类,继承自Dataset
class Pokemon(Dataset):
    # 一、初始化函数init
    # 第一个参数root:总的图片所在的位置,可以是任意的位置,我们的图片可以放在任意的位置,我们这里就存储在当前目录pokemon文件夹下。
    # 第二个参数resize:图片输出的size,是由这个参数所进行设定。
    # 第三个参数mode:这里我们需要做train、validation以及test,对应这三种数据结构,因此我们用一个list[0,1,2]来代表是哪个模式。
    def __init__(self, root, resize, mode):
        # 先调用母类的初始化函数:
        super(Pokemon, self).__init__()
        # 1、首先我们将这个参数保存下来
        self.root = root
        self.resize = resize

        # 2、给每一个分类做一个映射,即当前的皮卡丘、妙蛙种子等这个string类型所对应的label是多少,这个是需要我们人为进行编码的。
        self.name2label = {}  # 用字典来表示映射关系

        # 通过循环方式,将root路径下的文件夹名进行编码
        for name in sorted(os.listdir(os.path.join(root))):
            # 过滤掉非文件夹:如果不是dir,就过滤掉,此外我们还通过sorted排序的方法,将键值对关系固定下来
            if not os.path.isdir(os.path.join(root, name)):
                continue
            # 文件名做key,当前name2label的长度做value
            self.name2label[name] = len(self.name2label.keys())

        print(self.name2label)

        # 将self.load_csv的返回值images, labels赋予self.images, self.labels
        self.images, self.labels = self.load_csv('images.csv')

        # 四、不同比例模式下对图片数量进行划分
        if mode == 'train':  # 取60%做training
            # len(self.images)的长度是1167,取60%做为train模式的图片
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif mode == 'val':  # 取20%做validation, 60%-80%
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:  # mode为test,取80%到最末尾
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]

    # 二、创建一个csv,用于保存图片全路径和对应的标签label
    # 这个函数接受一个参数filename
    # 这个函数中需要将所有图片都load进来
    def load_csv(self, filename):
        # 需要一个判断,如果文件不存在,就需要创建csv,直接读取创建好的csv文件内容即可:
        # 如果不存在,就需要创建csv
        if not os.path.exists(os.path.join(self.root, filename)):

            images = []
            for name in self.name2label.keys():
                # 类别信息我们可以使用路径来判断
                # 比如:'pokemon\\mewtwo\\00001.png'
                # 上面路径的mewtwo就是类别
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))

            # 1167张, 'pokemon\\bulbasaur\\00000000.png'
            print(len(images), images)

            # 将images顺序打乱
            random.shuffle(images)

            # 打开这个文件
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                # 新建writer,写入csv这个文件对象
                writer = csv.writer(f)
                for img in images:  # 获得每行信息'pokemon\\bulbasaur\\00000000.png'
                    # 通过分割符,将每行信息的内容分割开,取导数第二个,类型
                    name = img.split(os.sep)[-2]

                    # 通过获取的类型名来获取label
                    label = self.name2label[name]

                    # 将这个label信息写到csv中
                    # csv是以逗号作为分割的
                    # 形式为:'pokemon\\bulbasaur\\00000000.png', 0
                    writer.writerow([img, label])
                print('writen into csv file:', filename)

        # 三、读取csv文件过程:
        # 这里需要在开头有一个判断,如果csv存在,就不用写入csv了,直接进行读取
        # 下次运行的时候只需加载进来即可
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            # 新建reader,读取csv这个文件对象
            reader = csv.reader(f)
            for row in reader:
                img, label = row  # 解包出来:'pokemon\\bulbasaur\\00000000.png', 0
                label = int(label)  # 将这个label转码为int类型
                # 将img每个图片路径,以及label保存在建立好的列表对象中。
                images.append(img)
                labels.append(label)
        print(images[:4])
        print(labels[:4])
        print('len(images):', len(images))
        print('len(labels):', len(labels))
        assert len(images) == len(labels)
        print('read csv file:', filename)
        return images, labels

    # 完成两个自定义的逻辑:
    # 1、样本的总体数量(图片总体数量),返回的是一个数字,总体图片大概有1168张,60%用于training,因此返回6-7百张图片
    # 五、完成总体样本数量函数的内容
    def __len__(self):
        # 这里的样本长度是跟模型类别来决定的,上面已经根据不同模型类型划分了样本数量了。
        # 不同模式下,样本长度是不同的。
        # 因此这里的总体样本长度,就是不同模式下的样本数量。
        return len(self.images)

    # 九、解决normalize处理后,visdom无法正常显示的问题
    # 这里传入的参数x是normalize过后的
    def denormalize(self, x_hat):
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # normalize的过程:
        # x_hat = (x-mean)/std
        # x = x_hat*std+mean 这个就是逆操作过程
        # x = [c,h,w]
        # mean = [3] => boardcasting后
        # unsqueeze()函数起升维的作用,参数表示在哪个地方加一个维度
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        print('mean.shape,std.shape:',mean.shape,std.shape)
        x = x_hat * std + mean
        return x

    # 2、用于返回当前index上面元素的值,这里是返回两个数据:
    # 需要返回当前image的data,以及image所对应的label[0,1,2,3,4]
    # 六、完成index与样本的一一对应
    def __getitem__(self, idx):
        # idx数值范围是[0-len(images)]
        # self.images保存了所有的数据;self.labels保存了所有数据对应的label信息;
        # img是一个string类型(还不是具体的图片,只是路径):'pokemon\\bulbasaur\\00000000.png'
        # label是一个整数类型
        img, label = self.images[idx], self.labels[idx]

        # 这里就需要将img所对应的路径读取出图片,并转为tensor类型
        # 这里我们可以Compose组合操作步骤
        # 八、增加数据预处理的工作,在Compose中增加这些内容,data augmentation数据增强
        # 这里我们做放大、旋转、裁切这三个数据增强的操作
        tf = transforms.Compose([
            # 这里需要将路径变成具体的图片数据类型
            # 即:string path => image data
            lambda x: Image.open(x).convert('RGB'),
            # Resize工作,这里的size是我们实例化时的self.resize的值
            # 1、data augmentation放大:在Resize设置的基础上,稍微调大一些size, 调整为1.25倍
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            # 2、data augmentation旋转:增加随机旋转,注意:这里旋转角度不能太大,会增加学习的难度。
            transforms.RandomRotation(15),
            # 3、data augmentation中心裁切:裁切为我们所需要的大小
            transforms.CenterCrop(self.resize),
            # 将数据变为tensor类型
            transforms.ToTensor(),
            # 4、normalize处理,希望图片数值范围在0左右分布,而不希望数值只分布在0的右侧或只在左侧
            # 其中参数统计的所有image net数据集几百万张图片的mean=[R的mean,G的mean,B的mean]和std=[R的方差,G的方差,B的方差]
            # 基本上这个数值是通用的
            # 数据通过Normalize处理后,就是在-1到1之间分布了。
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        img = tf(img)
        label = torch.tensor(label)
        return img, label


# 创建一个调试函数:
def main():
    # 七、验证自定义数据集
    # 验证需要一些辅助函数,用visdom做一些可视化。
    import visdom
    import time

    # 创建一个visdom这个对象
    viz = visdom.Visdom()
	
    db = Pokemon('pokemon', 64, 'train') #为了更好显示一次batch加载的图片,我们将图片大小从224调整为64。

    # 首先是可视化一个样本
    # next() 返回迭代器的下一个项目。
    # next()函数要和生成迭代器的iter()函数一起使用。
    x, y = next(iter(db))
    print('sample:', x.shape, y.shape, y)

    # 将图片可视化一下
    #这里通过第九步创建的denormalize函数,将x的normalize逆处理一下。
    viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))

    #十、一次加载batch个图片的功能
    #第一个参数:db对象
    #第二个参数:batch size为32
    #第三个参数:是否要打散shuffle,这里为True, 保证每次获取的batch个图片是随机获取的,db对象是根据idx来选取的,因此db是没有打乱顺序的。
    #第四个参数:num_workers设置cpu工作线程,8个线程,一次加载8张图片,batch为32,那么cpu只需加载32/8=4次,这样可以提高效率。
    loader = DataLoader(db, batch_size=32, shuffle=True,num_workers=8)

    for x,y in loader: #每行有8张图片,一共有4行,总共图片32张
        #参数norm:表示一行显示8张图片。
        viz.images(db.denormalize(x),nrow=8, win='batch', opts=dict(title = 'batch'))
        #将label也显示出来,这里label就是y
        #y是tensor类型,将y转换为numpy,在变换到string类型
        viz.text(str(y.numpy()), win='label', opts=dict(title = 'batch-y'))

        #每加载完一组batch就休息10秒
        time.sleep(10)

if __name__ == '__main__':
    main()

每行有8张图片,一共有4行,总共图片32张。
在这里插入图片描述

通过上面10步就完成了自定义数据集的加载和显示。

十一、除了上面这种较为复杂的加载方法,还有通过API较为简便的加载自定义数据集(这种加载方式有严格要求,如果有其他额外要求,就需要用前十步的方式加载数据)

pokemon.py

import torch
import os, glob
import random, csv

# 所有自定义数据集的一个母类
from torch.utils.data import Dataset, DataLoader

# 常用的图片变换器
from torchvision import transforms
# 从图片读取出数据
from PIL import Image


# 自定义数据集的类,继承自Dataset
class Pokemon(Dataset):
    # 一、初始化函数init
    # 第一个参数root:总的图片所在的位置,可以是任意的位置,我们的图片可以放在任意的位置,我们这里就存储在当前目录pokemon文件夹下。
    # 第二个参数resize:图片输出的size,是由这个参数所进行设定。
    # 第三个参数mode:这里我们需要做train、validation以及test,对应这三种数据结构,因此我们用一个list[0,1,2]来代表是哪个模式。
    def __init__(self, root, resize, mode):
        # 先调用母类的初始化函数:
        super(Pokemon, self).__init__()
        # 1、首先我们将这个参数保存下来
        self.root = root
        self.resize = resize

        # 2、给每一个分类做一个映射,即当前的皮卡丘、妙蛙种子等这个string类型所对应的label是多少,这个是需要我们人为进行编码的。
        self.name2label = {}  # 用字典来表示映射关系

        # 通过循环方式,将root路径下的文件夹名进行编码
        for name in sorted(os.listdir(os.path.join(root))):
            # 过滤掉非文件夹:如果不是dir,就过滤掉,此外我们还通过sorted排序的方法,将键值对关系固定下来
            if not os.path.isdir(os.path.join(root, name)):
                continue
            # 文件名做key,当前name2label的长度做value
            self.name2label[name] = len(self.name2label.keys())

        print(self.name2label)

        # 将self.load_csv的返回值images, labels赋予self.images, self.labels
        self.images, self.labels = self.load_csv('images.csv')

        # 四、不同比例模式下对图片数量进行划分
        if mode == 'train':  # 取60%做training
            # len(self.images)的长度是1167,取60%做为train模式的图片
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif mode == 'val':  # 取20%做validation, 60%-80%
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:  # mode为test,取80%到最末尾
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]

    # 二、创建一个csv,用于保存图片全路径和对应的标签label
    # 这个函数接受一个参数filename
    # 这个函数中需要将所有图片都load进来
    def load_csv(self, filename):
        # 需要一个判断,如果文件不存在,就需要创建csv,直接读取创建好的csv文件内容即可:
        # 如果不存在,就需要创建csv
        if not os.path.exists(os.path.join(self.root, filename)):

            images = []
            for name in self.name2label.keys():
                # 类别信息我们可以使用路径来判断
                # 比如:'pokemon\\mewtwo\\00001.png'
                # 上面路径的mewtwo就是类别
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))

            # 1167张, 'pokemon\\bulbasaur\\00000000.png'
            print(len(images), images)

            # 将images顺序打乱
            random.shuffle(images)

            # 打开这个文件
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                # 新建writer,写入csv这个文件对象
                writer = csv.writer(f)
                for img in images:  # 获得每行信息'pokemon\\bulbasaur\\00000000.png'
                    # 通过分割符,将每行信息的内容分割开,取导数第二个,类型
                    name = img.split(os.sep)[-2]

                    # 通过获取的类型名来获取label
                    label = self.name2label[name]

                    # 将这个label信息写到csv中
                    # csv是以逗号作为分割的
                    # 形式为:'pokemon\\bulbasaur\\00000000.png', 0
                    writer.writerow([img, label])
                print('writen into csv file:', filename)

        # 三、读取csv文件过程:
        # 这里需要在开头有一个判断,如果csv存在,就不用写入csv了,直接进行读取
        # 下次运行的时候只需加载进来即可
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            # 新建reader,读取csv这个文件对象
            reader = csv.reader(f)
            for row in reader:
                img, label = row  # 解包出来:'pokemon\\bulbasaur\\00000000.png', 0
                label = int(label)  # 将这个label转码为int类型
                # 将img每个图片路径,以及label保存在建立好的列表对象中。
                images.append(img)
                labels.append(label)
        print(images[:4])
        print(labels[:4])
        print('len(images):', len(images))
        print('len(labels):', len(labels))
        assert len(images) == len(labels)
        print('read csv file:', filename)
        return images, labels

    # 完成两个自定义的逻辑:
    # 1、样本的总体数量(图片总体数量),返回的是一个数字,总体图片大概有1168张,60%用于training,因此返回6-7百张图片
    # 五、完成总体样本数量函数的内容
    def __len__(self):
        # 这里的样本长度是跟模型类别来决定的,上面已经根据不同模型类型划分了样本数量了。
        # 不同模式下,样本长度是不同的。
        # 因此这里的总体样本长度,就是不同模式下的样本数量。
        return len(self.images)

    # 九、解决normalize处理后,visdom无法正常显示的问题
    # 这里传入的参数x是normalize过后的
    def denormalize(self, x_hat):
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # normalize的过程:
        # x_hat = (x-mean)/std
        # x = x_hat*std+mean 这个就是逆操作过程
        # x = [c,h,w]
        # mean = [3] => boardcasting后
        # unsqueeze()函数起升维的作用,参数表示在哪个地方加一个维度
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        print('mean.shape,std.shape:',mean.shape,std.shape)
        x = x_hat * std + mean
        return x

    # 2、用于返回当前index上面元素的值,这里是返回两个数据:
    # 需要返回当前image的data,以及image所对应的label[0,1,2,3,4]
    # 六、完成index与样本的一一对应
    def __getitem__(self, idx):
        # idx数值范围是[0-len(images)]
        # self.images保存了所有的数据;self.labels保存了所有数据对应的label信息;
        # img是一个string类型(还不是具体的图片,只是路径):'pokemon\\bulbasaur\\00000000.png'
        # label是一个整数类型
        img, label = self.images[idx], self.labels[idx]

        # 这里就需要将img所对应的路径读取出图片,并转为tensor类型
        # 这里我们可以Compose组合操作步骤
        # 八、增加数据预处理的工作,在Compose中增加这些内容,data augmentation数据增强
        # 这里我们做放大、旋转、裁切这三个数据增强的操作
        tf = transforms.Compose([
            # 这里需要将路径变成具体的图片数据类型
            # 即:string path => image data
            lambda x: Image.open(x).convert('RGB'),
            # Resize工作,这里的size是我们实例化时的self.resize的值
            # 1、data augmentation放大:在Resize设置的基础上,稍微调大一些size, 调整为1.25倍
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            # 2、data augmentation旋转:增加随机旋转,注意:这里旋转角度不能太大,会增加学习的难度。
            transforms.RandomRotation(15),
            # 3、data augmentation中心裁切:裁切为我们所需要的大小
            transforms.CenterCrop(self.resize),
            # 将数据变为tensor类型
            transforms.ToTensor(),
            # 4、normalize处理,希望图片数值范围在0左右分布,而不希望数值只分布在0的右侧或只在左侧
            # 其中参数统计的所有image net数据集几百万张图片的mean=[R的mean,G的mean,B的mean]和std=[R的方差,G的方差,B的方差]
            # 基本上这个数值是通用的
            # 数据通过Normalize处理后,就是在-1到1之间分布了。
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        img = tf(img)
        label = torch.tensor(label)
        return img, label


# 创建一个调试函数:
def main():
    # 七、验证自定义数据集
    # 验证需要一些辅助函数,用visdom做一些可视化。
    import visdom
    import time
    import torchvision #通过API较为简便的加载自定义数据集,需要引入torchvision

    # 创建一个visdom这个对象
    viz = visdom.Visdom()

    # 十一、通过API较为简便的加载自定义数据集(前提是数据集按照不同类型存储在对应类型命名的文件夹下面,并且这些不同类别的文件夹都存储在统一的一个文件夹下,只有这种固定的二级目录存储形式才能用这个API进行加载。)
    tf = transforms.Compose([
        transforms.Resize((64,64)),
        transforms.ToTensor()
    ])
    #参数1:传入路径
    #参数2:变换器,这个变换器就是进行resize操作
    db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)
    loader = DataLoader(db, batch_size=32, shuffle=True)

    print(db.class_to_idx) #通过这个就能知道不同类别是如何编码的了。

    for x,y in loader: #每行有8张图片,一共有4行,总共图片32张
        #参数norm:表示一行显示8张图片。
        viz.images(x,nrow=8, win='batch', opts=dict(title = 'batch'))
        #将label也显示出来,这里label就是y
        #y是tensor类型,将y转换为numpy,在变换到string类型
        viz.text(str(y.numpy()), win='label', opts=dict(title = 'batch-y'))

        #每加载完一组batch就休息10秒
        time.sleep(10)

    # # ==============注释掉这些内容============
    # db = Pokemon('pokemon', 64, 'train')
    #
    # # 首先是可视化一个样本
    # # next() 返回迭代器的下一个项目。
    # # next()函数要和生成迭代器的iter()函数一起使用。
    # x, y = next(iter(db))
    # print('sample:', x.shape, y.shape, y)
    #
    # # 将图片可视化一下
    # #这里通过第九步创建的denormalize函数,将x的normalize逆处理一下。
    # viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))
    #
    # #十、一次加载batch个图片的功能
    # #第一个参数:db对象
    # #第二个参数:batch size为32
    # #第三个参数:是否要打散shuffle,这里为True, 保证每次获取的batch个图片是随机获取的,db对象是根据idx来选取的,因此db是没有打乱顺序的。
    # #第四个参数:num_workers设置cpu工作线程,8个线程,一次加载8张图片,batch为32,那么cpu只需加载32/8=4次,这样可以提高效率。
    # loader = DataLoader(db, batch_size=32, shuffle=True,num_workers=8)
    #
    #
    # for x,y in loader: #每行有8张图片,一共有4行,总共图片32张
    #     #参数norm:表示一行显示8张图片。
    #     viz.images(db.denormalize(x),nrow=8, win='batch', opts=dict(title = 'batch'))
    #     #将label也显示出来,这里label就是y
    #     #y是tensor类型,将y转换为numpy,在变换到string类型
    #     viz.text(str(y.numpy()), win='label', opts=dict(title = 'batch-y'))
    #
    #     #每加载完一组batch就休息10秒
    #     time.sleep(10)
    # #===================================


if __name__ == '__main__':
    main()

在这里插入图片描述
在这里插入图片描述

Step2.创建分类器模型

对原来的ResNet18的模型进行稍微的修改:
文件名:resnet.py

import torch
from torch import nn
from torch.nn import functional as F

class ResBlk(nn.Module):
    """
    resnet block
    """
    def __init__(self, ch_in, ch_out, stride=1):
        """

        :param ch_in:
        :param ch_out:
        """
        super(ResBlk, self).__init__() #调用这个类的初始化方法,来初始化这个父类。

        #1、构建两个convolution单元
        #这里设置一个stride,用于减少数据量,衰减长和宽。
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        # Batch Normalization的目的是使我们的一批(Batch)feature map满足均值为0,方差为1的分布规律。
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        #2、创建短接调整输入维度模块
        #先判断一下输入的channel与输出的channel是不是不相同
        #网络结构写在Sequential中,可以方便的组织网络结构。
        #这部分是短接的额外单元extra module
        if ch_out != ch_in:
            #[b,ch_in,h,w] => [b,ch_out,h,w]
            #这里使用的是1*1的卷积单元,为了是只见ch_in改变为ch_out,其他size都不变。
            #此外还要确保参数stride与最开始第一次卷积的stride保持一致。这样长和宽才是一致的。
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out)
            )
        else:
            #这里是相同的情况,相同的话就什么也不做。
            self.extra = nn.Sequential()

    #3、编写forward函数
    def forward(self, x):
        """

        :param x: [b,ch,h,w]
        :return:
        """
        #4、调用init构建好的两个convolution单元,并在其中添加一个relu激活函数
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        #5、编写short cut部分
        #上面卷积层的结果,需要和最开始的输入进行相加操作。
        #这里需要注意的是,输出与输入的维度相同才可以进行矩阵相加
        #所以我们给短接部分添加一个模块,这个模块用于调整输入的维度,使其能够和输出相加。
        #这部分叫extra module
        #如果ch_in != ch_out
        #通过 extra module: 将[b,ch_in,h,w] =>变为 [b,ch_out,h,w]
        # element.wise add 矩阵各个元素相加:
        out = x = self.extra(x) + out
        out = F.relu(out)
        return out

class ResNet18(nn.Module):

    def __init__(self,num_class):
        super(ResNet18, self).__init__()

        #1、先创建一个卷积层,将ch_in=3转为ch_out=16。
        #这里的stride设置为3
        #padding设置为0
        self.conv1 = nn.Sequential(
            nn.Conv2d(3,16,kernel_size=3,stride=2,padding=0),
            nn.BatchNorm2d(16)

        )

        # 紧接着4个ResBlk部分followed 4 block
        # [b,16,h,w] => [b,32,h,w]
        self.blk1 = ResBlk(16, 32,stride=2)

        # [b,32,h,w] => [b,64,h,w]
        self.blk2 = ResBlk(32, 64, stride=2)

        # [b,64,h,w] => [b,128,h,w]
        self.blk3 = ResBlk(64, 128,stride=2)

        # [b,128,h,w] => [b,256,h,w]
        self.blk4 = ResBlk(128, 256,stride=2)

        #跟一个线性层linear,变成10类
        self.outlayer = nn.Linear(256*1*1, num_class)

    def forward(self,x):
        """

        :param x:
        :return:
        """

        #将self.conv1(x)经过激活函数处理
        x = F.relu(self.conv1(x))
        # print('经过初始卷积层和relu:',x.shape)

        #经过4个block
        # [b,64,h,w] => [b,1024,h,w]
        x = self.blk1(x)
        # print('经过第一个blk1:',x.shape)

        x = self.blk2(x)
        # print('经过第二个blk2:',x.shape)

        x = self.blk3(x)
        # print('经过第三个blk3:',x.shape)

        x = self.blk4(x)
        # print('经过第四个blk4:', x.shape)

        # print('after conv:', x.shape)  # [b, 512, 2, 2]

        #[b,512,h,w] => [b,512,1,1]
        # x = F.adaptive_avg_pool2d(x,[1,1])
        # print('经过adaptive_avg_pool2d:', x.shape)

        #取第一个维度,即batch维度,其他维度相乘512*1*1
        x = x.view(x.size(0), -1)
        # print('经过view:',x.shape) #torch.Size([2, 512])

        # 跟一个线性层linear,变成10类
        x = self.outlayer(x)

        return x

def main():
    # 这里我们希望长和宽减半,减少数据量,因为如果不变的话,channel维度不断增加,会导致数据量翻倍。
    #所以我们设置stride步长,让长宽减小
    blk = ResBlk(64,128, stride=1)
    # print(blk) #查看网络结构
    # 创建一个假的数据集
    tmp = torch.randn(2,64, 224, 224)
    out = blk(tmp)
    print('block shape:', out.shape) #block shape: torch.Size([2, 128, 8, 8])
    # print('block:',out)

    #创建一个假的数据集
    #这里测试的目的是查看是否报错,如果报错,说明网络中的shape存在不匹配match的情况。
    #如果match,就不会报错,如果不match,就会报错。
    x = torch.randn(2,3,64,64)
    # print('初始数据输入:',x.shape)
    model = ResNet18(5)
    out = model(x)
    print('resnet18结果:',out.shape)
	
	#可以查看参数量
    p = sum(map(lambda p:p.numel(), model.parameters()))
    print('parameters size:', p)

if __name__ == '__main__':
    main()

在这里插入图片描述

这里对于x = torch.randn(2,3,224,224),在第148行 ,我们将原来的64,修改为224。
在这里插入图片描述
会发现报错了:

import torch
from torch import nn
from torch.nn import functional as F

class ResBlk(nn.Module):
    """
    resnet block
    """
    def __init__(self, ch_in, ch_out, stride=1):
        """

        :param ch_in:
        :param ch_out:
        """
        super(ResBlk, self).__init__() #调用这个类的初始化方法,来初始化这个父类。

        #1、构建两个convolution单元
        #这里设置一个stride,用于减少数据量,衰减长和宽。
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        # Batch Normalization的目的是使我们的一批(Batch)feature map满足均值为0,方差为1的分布规律。
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        #2、创建短接调整输入维度模块
        #先判断一下输入的channel与输出的channel是不是不相同
        #网络结构写在Sequential中,可以方便的组织网络结构。
        #这部分是短接的额外单元extra module
        if ch_out != ch_in:
            #[b,ch_in,h,w] => [b,ch_out,h,w]
            #这里使用的是1*1的卷积单元,为了是只见ch_in改变为ch_out,其他size都不变。
            #此外还要确保参数stride与最开始第一次卷积的stride保持一致。这样长和宽才是一致的。
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out)
            )
        else:
            #这里是相同的情况,相同的话就什么也不做。
            self.extra = nn.Sequential()

    #3、编写forward函数
    def forward(self, x):
        """

        :param x: [b,ch,h,w]
        :return:
        """
        #4、调用init构建好的两个convolution单元,并在其中添加一个relu激活函数
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        #5、编写short cut部分
        #上面卷积层的结果,需要和最开始的输入进行相加操作。
        #这里需要注意的是,输出与输入的维度相同才可以进行矩阵相加
        #所以我们给短接部分添加一个模块,这个模块用于调整输入的维度,使其能够和输出相加。
        #这部分叫extra module
        #如果ch_in != ch_out
        #通过 extra module: 将[b,ch_in,h,w] =>变为 [b,ch_out,h,w]
        # element.wise add 矩阵各个元素相加:
        out = x = self.extra(x) + out
        out = F.relu(out)
        return out

class ResNet18(nn.Module):

    def __init__(self,num_class):
        super(ResNet18, self).__init__()

        #1、先创建一个卷积层,将ch_in=3转为ch_out=16。
        #这里的stride设置为3
        #padding设置为0
        self.conv1 = nn.Sequential(
            nn.Conv2d(3,16,kernel_size=3,stride=2,padding=0),
            nn.BatchNorm2d(16)

        )

        # 紧接着4个ResBlk部分followed 4 block
        # [b,16,h,w] => [b,32,h,w]
        self.blk1 = ResBlk(16, 32,stride=2)

        # [b,32,h,w] => [b,64,h,w]
        self.blk2 = ResBlk(32, 64, stride=2)

        # [b,64,h,w] => [b,128,h,w]
        self.blk3 = ResBlk(64, 128,stride=2)

        # [b,128,h,w] => [b,256,h,w]
        self.blk4 = ResBlk(128, 256,stride=2)

        #跟一个线性层linear,变成5类
        self.outlayer = nn.Linear(256*1*1, num_class)

    def forward(self,x):
        """

        :param x:
        :return:
        """

        #将self.conv1(x)经过激活函数处理
        x = F.relu(self.conv1(x))
        # print('经过初始卷积层和relu:',x.shape)

        #经过4个block
        # [b,64,h,w] => [b,1024,h,w]
        x = self.blk1(x)
        # print('经过第一个blk1:',x.shape)

        x = self.blk2(x)
        # print('经过第二个blk2:',x.shape)

        x = self.blk3(x)
        # print('经过第三个blk3:',x.shape)

        x = self.blk4(x)
        # print('经过第四个blk4:', x.shape)

        # print('after conv:', x.shape)  # [b, 512, 2, 2]

        #[b,512,h,w] => [b,512,1,1]
        # x = F.adaptive_avg_pool2d(x,[1,1])
        # print('经过adaptive_avg_pool2d:', x.shape)

        #取第一个维度,即batch维度,其他维度相乘512*1*1
        x = x.view(x.size(0), -1)
        # print('经过view:',x.shape) #torch.Size([2, 512])

        # 跟一个线性层linear,变成10类
        x = self.outlayer(x)

        return x

def main():
    # 这里我们希望长和宽减半,减少数据量,因为如果不变的话,channel维度不断增加,会导致数据量翻倍。
    #所以我们设置stride步长,让长宽减小
    blk = ResBlk(64,128)
    # print(blk) #查看网络结构
    # 创建一个假的数据集
    tmp = torch.randn(2,64, 224, 224)
    out = blk(tmp)
    print('block shape:', out.shape) #block shape: torch.Size([2, 128, 8, 8])
    # print('block:',out)

    #创建一个假的数据集
    #这里测试的目的是查看是否报错,如果报错,说明网络中的shape存在不匹配match的情况。
    #如果match,就不会报错,如果不match,就会报错。
    x = torch.randn(2,3,224,224) #这里我们将原来的64,修改为224
    # print('初始数据输入:',x.shape)
    model = ResNet18(5)
    out = model(x)
    print('resnet18结果:',out.shape)

    p = sum(map(lambda p:p.numel(), model.parameters()))
    print('parameters size:', p)

if __name__ == '__main__':
    main()

在这里插入图片描述
这个错误说明这个模型与输入x之间不是很匹配。

因此需要调整self.blk4与self.outlayer之间的输出与输入关系:
经过第四个blk4: torch.Size([2, 256, 7, 7]),我们希望参数量少一些,因此将conv1、blk1和blk1的stride设置的大一些,这样就可以让参数量少一些。我将Linear中的输入调整一下为25633。
这样将self.blk4与self.outlayer之间进行匹配后,发现是可以正常运行的了:

import torch
from torch import nn
from torch.nn import functional as F

class ResBlk(nn.Module):
    """
    resnet block
    """
    def __init__(self, ch_in, ch_out, stride=1):
        """

        :param ch_in:
        :param ch_out:
        """
        super(ResBlk, self).__init__() #调用这个类的初始化方法,来初始化这个父类。

        #1、构建两个convolution单元
        #这里设置一个stride,用于减少数据量,衰减长和宽。
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        # Batch Normalization的目的是使我们的一批(Batch)feature map满足均值为0,方差为1的分布规律。
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        #2、创建短接调整输入维度模块
        #先判断一下输入的channel与输出的channel是不是不相同
        #网络结构写在Sequential中,可以方便的组织网络结构。
        #这部分是短接的额外单元extra module
        if ch_out != ch_in:
            #[b,ch_in,h,w] => [b,ch_out,h,w]
            #这里使用的是1*1的卷积单元,为了是只见ch_in改变为ch_out,其他size都不变。
            #此外还要确保参数stride与最开始第一次卷积的stride保持一致。这样长和宽才是一致的。
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out)
            )
        else:
            #这里是相同的情况,相同的话就什么也不做。
            self.extra = nn.Sequential()

    #3、编写forward函数
    def forward(self, x):
        """

        :param x: [b,ch,h,w]
        :return:
        """
        #4、调用init构建好的两个convolution单元,并在其中添加一个relu激活函数
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        #5、编写short cut部分
        #上面卷积层的结果,需要和最开始的输入进行相加操作。
        #这里需要注意的是,输出与输入的维度相同才可以进行矩阵相加
        #所以我们给短接部分添加一个模块,这个模块用于调整输入的维度,使其能够和输出相加。
        #这部分叫extra module
        #如果ch_in != ch_out
        #通过 extra module: 将[b,ch_in,h,w] =>变为 [b,ch_out,h,w]
        # element.wise add 矩阵各个元素相加:
        out = x = self.extra(x) + out
        out = F.relu(out)
        return out

class ResNet18(nn.Module):

    def __init__(self,num_class):
        super(ResNet18, self).__init__()

        #1、先创建一个卷积层,将ch_in=3转为ch_out=16。
        #这里的stride设置为3
        #padding设置为0
        self.conv1 = nn.Sequential(
            nn.Conv2d(3,16,kernel_size=3,stride=3,padding=0),
            nn.BatchNorm2d(16)

        )

        # 紧接着4个ResBlk部分followed 4 block
        # [b,16,h,w] => [b,32,h,w]
        self.blk1 = ResBlk(16, 32,stride=3)

        # [b,32,h,w] => [b,64,h,w]
        self.blk2 = ResBlk(32, 64, stride=3)

        # [b,64,h,w] => [b,128,h,w]
        self.blk3 = ResBlk(64, 128,stride=2)

        # [b,128,h,w] => [b,256,h,w]
        self.blk4 = ResBlk(128, 256,stride=2)

        #跟一个线性层linear,变成10类
        #这里是[b, 256, 7, 7] ,我们希望参数量少一些,因此将conv1、blk1和blk1的stride设置的大一些,这样就可以让参数量少一些。
        #我将Linear中的输入调整一下为256*3*3
        self.outlayer = nn.Linear(256*3*3, num_class)

    def forward(self,x):
        """

        :param x:
        :return:
        """

        #将self.conv1(x)经过激活函数处理
        x = F.relu(self.conv1(x))
        # print('经过初始卷积层和relu:',x.shape)

        #经过4个block
        # [b,64,h,w] => [b,1024,h,w]
        x = self.blk1(x)
        # print('经过第一个blk1:',x.shape)

        x = self.blk2(x)
        # print('经过第二个blk2:',x.shape)

        x = self.blk3(x)
        # print('经过第三个blk3:',x.shape)

        x = self.blk4(x)
        print('经过第四个blk4:', x.shape) #经过4个ResBlk后x的情况。
        #经过第四个blk4: torch.Size([2, 256, 7, 7])

        # print('after conv:', x.shape)  # [b, 512, 2, 2]

        #[b,512,h,w] => [b,512,1,1]
        # x = F.adaptive_avg_pool2d(x,[1,1])
        # print('经过adaptive_avg_pool2d:', x.shape)

        #取第一个维度,即batch维度,其他维度相乘512*1*1
        x = x.view(x.size(0), -1)
        # print('经过view:',x.shape) #torch.Size([2, 512])

        # 跟一个线性层linear,变成10类
        x = self.outlayer(x)

        return x

def main():
    # 这里我们希望长和宽减半,减少数据量,因为如果不变的话,channel维度不断增加,会导致数据量翻倍。
    #所以我们设置stride步长,让长宽减小
    blk = ResBlk(64,128)
    # print(blk) #查看网络结构
    # 创建一个假的数据集
    tmp = torch.randn(2,64, 224, 224)
    out = blk(tmp)
    print('block shape:', out.shape) #block shape: torch.Size([2, 128, 8, 8])
    # print('block:',out)

    #创建一个假的数据集
    #这里测试的目的是查看是否报错,如果报错,说明网络中的shape存在不匹配match的情况。
    #如果match,就不会报错,如果不match,就会报错。
    x = torch.randn(2,3,224,224) #这里我们将原来的64,修改为224
    # print('初始数据输入:',x.shape)
    model = ResNet18(5)
    out = model(x)
    print('resnet18结果:',out.shape)
    #resnet18结果: torch.Size([2, 5])

    #这一步就是求模型中的参数量:
    #model.parameters()获取模型中的每一个参数,p.numel()每个参数所占用内存的大小,并sum求和。
    #map() 会根据提供的函数对指定序列做映射。map(function, iterable, ...)
    #每一个参数的参数量进行相加
    p = sum(map(lambda p:p.numel(), model.parameters()))
    print('parameters size:', p)

if __name__ == '__main__':
    main()

在这里插入图片描述
这里参数量1234885属于中等大小,现在上十亿的参数量也是比较常见的。

这样网络的参数就完成了。

Step3.Train and Test

文件名:train_scratch.py

import torch
from torch import optim,nn
import visdom
import torchvision
from torch.utils.data import DataLoader

from pokemon import Pokemon
from resnet_test import ResNet18

batchsz = 32
lr = 1e-3
epochs = 10

device = torch.device('cuda')
torch.manual_seed(1234) #这个是随机数种子,保证每次都能复现出来。

#这里是需要实例化Pokemon类
#这里之所以使用224,是因为是ResNet最适合的大小。
train_db = Pokemon('pokemon',224,'train')
val_db = Pokemon('pokemon',224,'val')
test_db = Pokemon('pokemon',224,'test')

#批量加载数据
#参数num_workers表示工作线程数:
train_loader = DataLoader(train_db
                          ,batch_size=batchsz
                          ,shuffle=True
                          ,num_workers=4)

val_loader = DataLoader(val_db
                        , batch_size=batchsz
                        , num_workers=2)

test_loader = DataLoader(test_db
                         , batch_size=batchsz
                         , num_workers=2)

#需要把train的进度保存下来,需要用到visdom
viz = visdom.Visdom()

#建立一个测试函数:测试函数针对validation和test功能是一样的
def evalute(model,loader):
    #用于统计总的预测正确的数量
    correct = 0
    #总的测试数量
    total = len(loader.dataset)
    for x,y in loader:
        x,y = x.to(device),y.to(device)
        with torch.no_grad():#test和validation是不需要梯度信息的
            logits = model(x)
            pred = logits.argmax(dim=1) #最大的值所在的位置
        #总的预测正确的数量,累加操作
        correct += torch.eq(pred,y).sum().float().item()
    accuracy = correct/ total
    return accuracy

def main():
    #实例化模型
    model = ResNet18(5).to(device)
    #创建一个优化器Adam,这个优化器比较好
    optimizer = optim.Adam(model.parameters(),lr=lr)

    #Loss的计算方法:CrossEntropyLoss;
    #这个Loss所接受的参数是logits,logits是不需要经过一个softmax的,只需要得到logits即可。
    criteon = nn.CrossEntropyLoss()

    #用于保存模型的训练状态
    best_acc, best_epoch = 0,0

    #step每次都是从0开始的,因此这里我们创建一个全局step
    global_step = 0

    #用visdom工具保存下accuracy和loss
    #training和loss的曲线
    #x=0,y=-1是初始状态
    viz.line([0],[-1],win='loss',opts = dict(title='loss'))
    # training和validation accuracy的曲线
    viz.line([0],[-1],win='val_acc',opts = dict(title='val_acc'))

    #training逻辑
    for epoch in range(epochs):
        for step,(x,y) in enumerate(train_loader):
            #x:[b,3,224,224]; y:[b]
            x,y = x.to(device), y.to(device) #x和y都转移到cuda上面

            #执行forward函数
            logits = model(x) #学出的预测结果
            #在pytorch中crossEntropyLoss中,传入的真实值y不需要进行one-hot操作,不需要做one-hot编码,会在内部做one-hot。
            #所以我们直接传入y就可以了。
            loss = criteon(logits,y) #预测结果与真实值进行交叉熵计算

            #前向传播和迭代过程
            #优化器
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # 用visdom工具保存下accuracy和loss
            # 每一个step我都要记录下来
            # validation和loss的曲线
            # x=loss.item()loss是一个tensor,因此需要通过item转为具体数值,y=-1是初始状态
            #参数update为append,表示添加到曲线的末尾。
            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        #这里我们每完成两个epoch就做一组validation
        if epoch % 1 == 0:
            #我们根据validation accuracy来选择要不要保存这个模型的训练状态。
            val_acc = evalute(model , val_loader)
            #如果当前accuracy大于best_acc,就保存当前的状态:
            if val_acc>best_acc:
                best_epoch = epoch
                best_acc = val_acc
                #保存当前模型的状态:
                #参数一:模型状态值
                #参数二:模型状态保存的文件名,文件名后缀随意
                torch.save(model.state_dict(),'best.mdl')

                # validation和 accuracy的曲线
                # 这里val_acc是数值型,所以不需要转换。
                viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:',best_acc,'best epoch:',best_epoch)

    #从最好的状态加载模型:
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from check point!')

    #上面加载了最好的模型状态,这里使用的模型也是最好的状态时的模型
    test_acc = evalute(model, test_loader)
    print('test_acc:',test_acc)




if __name__ == '__main__':
    main()

在这里插入图片描述
在这里插入图片描述
best.mdl这个文件就是我们保存的最好模型状态的信息:
在这里插入图片描述
我们发现test_acc: 0.8247863247863247,这个结果并不是很好,我们如何将这个结果提升上来呢?

Step4. Transfer learning迁移学习

这里我们注意到数据量太少了,对于我们的模型进行training来说是不够的。因此就很容易出现overfitting的情况。

这里我们训练的图片是Pokemon,和ImageNet的图片中某些类型图片存在一定的重合,存在某些共性knowledge(即左下角的图片分区的情况)。
在这里插入图片描述
既然存在这种情况,我们能不能用ImageNet的重合类型图片帮助我们来进行训练呢?是可以的,在A任务上train好一个分类器,再transfer到B上去。迁移学习的目的是在获取一定的额外数据或者是存在一个已有的模型的前提下,将其应用在新的且有一定相关性的task。

transfer learning具体如何实现呢?

在ImageNet上training好的网络模型,即common knowledge。我们利用common knowledge把前面的这些knowledge保持不动,并将最后一层去掉。(即我们transfer的是A部分,B这部分我们不transfer)。这里相对于从随机初始化来说,我们是从有良好knowledge的网络状态开始初始化training的,通过良好的初始化与新的分类模型结合,这样模型性能就远远好于之前training scratch(从零开始)。
在这里插入图片描述
utils.py 用于打平操作

from    matplotlib import pyplot as plt
import  torch
from    torch import nn

#该函数是一个标准的打平层
class Flatten(nn.Module):
    #该文件utils包含一些辅助函数。
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)

#该函数是将img打印到matplotlib上
def plot_image(img, label, name):

    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()

train_transfer.py:

import torch
from torch import optim,nn
import visdom
import torchvision
from torch.utils.data import DataLoader

from pokemon import Pokemon

#这里我们将ResNet18注释掉,加载已经training好的ResNet18模型
# from resnet_test import ResNet18
from torchvision.models import resnet18 #这个resnet18是已经training好的状态

from utils import Flatten #用于打平,这个是自己来实现的打平层

batchsz = 32
lr = 1e-3
epochs = 10

device = torch.device('cuda')
torch.manual_seed(1234) #这个是随机数种子,保证每次都能复现出来。

#这里是需要实例化Pokemon类
#这里之所以使用224,是因为是ResNet最适合的大小。
train_db = Pokemon('pokemon',224,'train')
val_db = Pokemon('pokemon',224,'val')
test_db = Pokemon('pokemon',224,'test')

#批量加载数据
#参数num_workers表示工作线程数:
train_loader = DataLoader(train_db
                          ,batch_size=batchsz
                          ,shuffle=True
                          ,num_workers=4)

val_loader = DataLoader(val_db
                        , batch_size=batchsz
                        , num_workers=2)

test_loader = DataLoader(test_db
                         , batch_size=batchsz
                         , num_workers=2)

#需要把train的进度保存下来,需要用到visdom
viz = visdom.Visdom()

#建立一个测试函数:测试函数针对validation和test功能是一样的
def evalute(model,loader):
    #用于统计总的预测正确的数量
    correct = 0
    #总的测试数量
    total = len(loader.dataset)
    for x,y in loader:
        x,y = x.to(device),y.to(device)
        with torch.no_grad():#test和validation是不需要梯度信息的
            logits = model(x)
            pred = logits.argmax(dim=1) #最大的值所在的位置
        #总的预测正确的数量,累加操作
        correct += torch.eq(pred,y).sum().float().item()
    accuracy = correct/ total
    return accuracy

def main():
    #实例化模型
    #我们将原来的ResNet18注释掉,使用transfer learning来实现一下
    # model = ResNet18(5).to(device)
    #===============具体的transfer learning实现内容==============
    #使用已经训练好的resnet18模型,一定要设置这个参数pretrained=True
    trained_model = resnet18(pretrained=True)
    #我们要使用训练好的resnet18模型的A部分,即取出前17层:
    #Sequential结束的是一个打散的数据,所有我们在list前加一个*,*args:接收若干个位置参数,转换成元组tuple形式。
    model = nn.Sequential(*list(trained_model.children())[:-1] #model的前17层(即A部分)返回的结果是:[b,512,1,1]
                          ,Flatten() #打平操作从[b,512,1,1]=>[b,512]
                          ,nn.Linear(512,5) #这层是最后那层,用于从新学习分成5类。
                          ).to(device)
    #我们从已经训练好的resnet18开始训练效果会好很多

    # # 这里我们测试一下
    # x = torch.randn(2,3,224,224)
    # print(model(x).shape)#打印结果为:torch.Size([2, 5])
    # #这样就实现了transfer learning
    # ======================================================

    #创建一个优化器Adam,这个优化器比较好
    optimizer = optim.Adam(model.parameters(),lr=lr)

    #Loss的计算方法:CrossEntropyLoss;
    #这个Loss所接受的参数是logits,logits是不需要经过一个softmax的,只需要得到logits即可。
    criteon = nn.CrossEntropyLoss()

    #用于保存模型的训练状态
    best_acc, best_epoch = 0,0

    #step每次都是从0开始的,因此这里我们创建一个全局step
    global_step = 0

    #用visdom工具保存下accuracy和loss
    #training和loss的曲线
    #x=0,y=-1是初始状态
    viz.line([0],[-1],win='loss',opts = dict(title='loss'))
    # training和validation accuracy的曲线
    viz.line([0],[-1],win='val_acc',opts = dict(title='val_acc'))

    #training逻辑
    for epoch in range(epochs):
        for step,(x,y) in enumerate(train_loader):
            #x:[b,3,224,224]; y:[b]
            x,y = x.to(device), y.to(device) #x和y都转移到cuda上面

            #执行forward函数
            logits = model(x) #学出的预测结果
            #在pytorch中crossEntropyLoss中,传入的真实值y不需要进行one-hot操作,不需要做one-hot编码,会在内部做one-hot。
            #所以我们直接传入y就可以了。
            loss = criteon(logits,y) #预测结果与真实值进行交叉熵计算

            #前向传播和迭代过程
            #优化器
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # 用visdom工具保存下accuracy和loss
            # 每一个step我都要记录下来
            # validation和loss的曲线
            # x=loss.item()loss是一个tensor,因此需要通过item转为具体数值,y=-1是初始状态
            #参数update为append,表示添加到曲线的末尾。
            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        #这里我们每完成两个epoch就做一组validation
        if epoch % 1 == 0:
            #我们根据validation accuracy来选择要不要保存这个模型的训练状态。
            val_acc = evalute(model , val_loader)
            #如果当前accuracy大于best_acc,就保存当前的状态:
            if val_acc>best_acc:
                best_epoch = epoch
                best_acc = val_acc
                #保存当前模型的状态:
                #参数一:模型状态值
                #参数二:模型状态保存的文件名,文件名后缀随意
                torch.save(model.state_dict(),'best.mdl')

                # validation和 accuracy的曲线
                # 这里val_acc是数值型,所以不需要转换。
                viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:',best_acc,'best epoch:',best_epoch)

    #从最好的状态加载模型:
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from check point!')

    #上面加载了最好的模型状态,这里使用的模型也是最好的状态时的模型
    test_acc = evalute(model, test_loader)
    print('test_acc:',test_acc)




if __name__ == '__main__':
    main()

在这里插入图片描述
在这里插入图片描述
发现test_acc: 0.9401709401709402,有了明显的提升,效果是非常非常好的。

总结:

这次一共三方面内容:
1、Load custom data 自定义数据集
2、Train from scratch 从头开始学习
3、Transfer learning 迁移学习,在已有训练好的模型上开始学习

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch-10 自定义数据集实战(Load data自定义数据集、Build model创建一个模型、Train and Test、Transfer Learning迁移学习) 的相关文章

  • 删除通过pandas创建的html表格的边框

    我正在使用 python 脚本在网页上显示数据框 我用了df to html将我的数据框转换为 HTML 但是 默认情况下 它将边框设置为 0 我尝试通过自定义 css 模板来覆盖它 但它不起作用 这是我的熊猫代码 ricSubscript
  • 在Python中将字符串转换为字典或列表?

    在Python中将此字符串转换为列表或字典 u f i r s t n a m e u j o h n u l a s t n a m e u s m i t h u a g e 2 0 u m o b
  • 使用 python sqlalchemy 通过WITH语句执行原始查询

    我正在尝试使用原始 sqlalchemy 将值插入 Postgres11 数据库text 查询 当我通过 psql client 运行以下 SQL 查询时 它可以正常工作 WITH a AS INSERT INTO person id VA
  • 字幕重新格式化以完整句子结尾

    我有以下 srt 字幕 文件 import pysrt srt 01 00 02 14 000 gt 00 02 18 000 I understand how customers do their choice So 02 00 02 1
  • 如何在旧数据库中的 Django 中进行 INNER JOIN ?

    抱歉 我的问题可能很简单 但我是 Django 的新手 真的很困惑 我有一个丑陋的旧表 我无法更改 它有 2 个表 class Salespersons models Model id models IntegerField unique
  • 如何在Python中测量时间?

    我想启动我的程序 测量程序启动的时间 然后等待几秒钟 按下按钮 K RIGHT 并测量按下按钮的时间 我正在使用 Pygame 来注册 Keydown 但在我下面的代码中它没有注册我的 Keydown 我在这里做错了什么 start tim
  • facebook graph api 调用 python 中的 appsecret_proof

    在 python 中使用 appsecret proof 参数进行图形 api 调用的正确方法是什么 有没有允许这样的图书馆 我试图使用 python for facebook 库 但文档实际上不存在 所以我无法弄清楚 您可以使用以下方法来
  • 在Python中将SQL转换为json[重复]

    这个问题在这里已经有答案了 我需要传递一个可以使用它进行转换的对象 parseJSON 查询如下所示 cursor execute SELECT earnings date FROM table 为了传递可以转换为 json 的 HttpR
  • ValueError:展开时包装器循环

    我的示例代码中的 Python3 测试用例 文档测试 失败 但在 Python2 中同样可以正常工作 test py class Test object def init self a 0 self a a def getattr self
  • Networkx - 最短路径长度

    我在用着networkx管理由 50k 个节点组成的大型网络图 我想计算一组特定节点 例如 N 之间的最短路径长度 为此我正在使用nx shortest path length功能 在 N 的某些节点中可能没有路径 因此 networkx
  • 使用 imaplib 库连接到电子邮件时遇到 AUTHENTICATIONFAILED 错误

    如何连接到 imaplib 库而不遇到 AUTHENTICATIONFAILE 错误 通过网络浏览器登录时 我的 Gmail 收件箱显示严重的安全警报 登录尝试被阻止 IMAP SERVER imap gmail com USERNAME
  • 像多米诺骨牌一样对 Python 中的元组进行排序/查找顶点连接

    我有一个像这样的整数元组列表 L 1 2 7 6 2 3 8 5 3 8 5 7 每对定义两个顶点之间的边 我想找到顶点连接性 没有循环 元组总是像多米诺骨牌一样唯一地链接起来 因此在这种情况下 排序列表应如下所示 L sorted 1 2
  • Python 有限边界 Voronoi 单元

    我正在尝试改编我在 stackoverflow 上找到的代码来创建具有有限边界的 voronoi 单元 我发现下面的代码https stackoverflow com a 20678647 2443944 https stackoverfl
  • PyInstaller,规范文件,导入错误:没有名为“blah”的模块

    我正在尝试通过构建 python 脚本py安装程序 http www pyinstaller org 我使用以下命令来配置 生成规范文件并构建 wget pyinstaller zip extracted it python Configu
  • cython.parallel.prange 中的 cython 共享内存 - 块

    我有一个函数foo它以指向内存的指针作为参数 并写入和读取该内存 cdef void foo double data data some index int some value double do something dependent
  • 在unittest.main()之后执行命令

    我从另一个 Python 脚本调用以下脚本 测试 py 日志文件 它应该运行测试并将结果保存在日志文件中 但由于某种原因 之后的命令unittest main testRunner runner 没有被执行 我什至不确定文件写入后是否会关闭
  • Matplotlib:以数据坐标中给定的宽度绘制线条

    我试图弄清楚如何绘制具有数据单位宽度的线条 例如 在下面的代码片段中 我希望宽度为 80 的线的水平部分始终从 y 40 延伸到 y 40 标记 并且即使坐标系的限制也保持这种状态改变 有没有办法用 matplotlib 中的 Line2D
  • 一旦相关命令更改,如何自动运行 py.test?

    通过autonose或nosy 一旦某些测试文件或相关文件发生更改 它将自动运行nosetests 请问py test是否提供了类似的功能 有没有其他工具可以自动激发py test 您可以安装pytest xdist 插件 http pyp
  • 以任意深度嵌套 defaultdict

    我想嵌套任意数量的默认字典 如下所示 from collections import defaultdict D defaultdict lambda defaultdict int 正如所描述的那样工作正常earlier https st
  • 在 Raspberry Pi 4 上的多个输出设备上播放多个 mp3 文件

    我需要 4 8 个同时播放立体声音频音乐频道 连续播放 SD 卡上特定文件夹中的 mp3 音乐 Working 板载 3 5 音频插孔 USB声卡正常播放音乐 Problem 但一旦我尝试在树莓派上使用带有 USB 声卡的第三个音频输出 其

随机推荐