Pytorch加载自己的数据集(以图片格式的Mnist数据集为例)

2023-10-27


Pytorch加载自己的数据集(以图片格式的Mnist数据集为例)

前言

初学pytorch,看了很多教程,发现所有教程在加载数据集的时候都用的pytorch已经定义好的模块,没有详细讲到如何使用Dataset和DataLoader加载自己格式多样的数据集,经过一段时间研究,成功跑通以图片为训练数据集的简单分类模型,现记录如下。

数据集在这里:
链接: https://pan.baidu.com/s/16T1IoAgOsepLqFRzjDck3g?pwd=h254 提取码: h254 复制这段内容后打开百度网盘手机App,操作更方便哦


一、数据集转换

Mnist是非常经典的数据集之一,从官网下载得到的是二进制的文件,与我们常用的图片格式不符,所以先将二进制文件转换为图像。
在这里插入图片描述

转换代码如下:

# -*- coding: utf-8 -*-
import numpy as np
import struct
import os
import cv2


class DataUtils(object):
    def __init__(self, filename=None, outpath=None):
        self._filename = filename
        self._outpath = outpath

        self._tag = '>'  # 大端格式
        self._twoBytes = 'II'
        self._fourBytes = 'IIII'
        self._pictureBytes = '784B'
        self._labelByte = '1B'
        self._twoBytes2 = self._tag + self._twoBytes
        self._fourBytes2 = self._tag + self._fourBytes
        self._pictureBytes2 = self._tag + self._pictureBytes
        self._labelByte2 = self._tag + self._labelByte

        self._imgNums = 0
        self._LabelNums = 0

    def getImage(self):
        """
        将MNIST的二进制文件转换成像素特征数据
        """
        binfile = open(self._filename, 'rb')  # 以二进制方式打开文件
        buf = binfile.read()
        binfile.close()
        index = 0
        numMagic, self._imgNums, numRows, numCols = struct.unpack_from(self._fourBytes2, buf, index)
        index += struct.calcsize(self._fourBytes)
        images = []
        print('image nums: %d' % self._imgNums)
        for i in range(self._imgNums):
            imgVal = struct.unpack_from(self._pictureBytes2, buf, index)
            index += struct.calcsize(self._pictureBytes2)
            imgVal = list(imgVal)
            images.append(imgVal)
        return np.array(images), self._imgNums

    def getLabel(self):
        """
        将MNIST中label二进制文件转换成对应的label数字特征
        """
        binFile = open(self._filename, 'rb')
        buf = binFile.read()
        binFile.close()
        index = 0
        magic, self._LabelNums = struct.unpack_from(self._twoBytes2, buf, index)
        index += struct.calcsize(self._twoBytes2)
        labels = []
        for x in range(self._LabelNums):
            im = struct.unpack_from(self._labelByte2, buf, index)
            index += struct.calcsize(self._labelByte2)
            labels.append(im[0])
        return np.array(labels)

    def outImg(self, arrX, arrY, imgNums):
        """
        根据生成的特征和数字标号,输出图像
        """
        output_txt = self._outpath + '/img.txt'
        output_file = open(output_txt, 'a+')

        m, n = np.shape(arrX)
        # 每张图是28*28=784Byte
        for i in range(imgNums):
            img = np.array(arrX[i])
            img = img.reshape(28, 28)
            # print(img)
            outfile = str(i) + "_" + str(arrY[i]) + ".bmp"
            # print('saving file: %s' % outfile)

            txt_line = outfile + " " + str(arrY[i]) + '\n'
            output_file.write(txt_line)
            cv2.imwrite(self._outpath + '/' + outfile, img)
        output_file.close()


if __name__ == '__main__':
    # 二进制文件路径,需要修改,和自己的相对应
    trainfile_X = 'C:\\Users\\60058670\\Desktop\\MNIST\\train-images.idx3-ubyte'
    trainfile_y = 'C:\\Users\\60058670\\Desktop\\MNIST\\train-labels.idx1-ubyte'
    testfile_X = 'C:\\Users\\60058670\\Desktop\\MNIST\\t10k-images.idx3-ubyte'
    testfile_y = 'C:\\Users\\60058670\\Desktop\\MNIST\\t10k-labels.idx1-ubyte'

    # 加载mnist数据集
    train_X, train_img_nums = DataUtils(filename=trainfile_X).getImage()
    train_y = DataUtils(filename=trainfile_y).getLabel()
    test_X, test_img_nums = DataUtils(testfile_X).getImage()
    test_y = DataUtils(testfile_y).getLabel()

    # 以下内容是将图像保存到本地文件中
    path_trainset = "C:\\Users\\60058670\\Desktop\\MNIST\\train"
    path_testset = "C:\\Users\\60058670\\Desktop\\MNIST\\test"
    if not os.path.exists(path_trainset):
        os.mkdir(path_trainset)
    if not os.path.exists(path_testset):
        os.mkdir(path_testset)
    DataUtils(outpath=path_trainset).outImg(train_X, train_y, int(train_img_nums / 10))  # /10是只转换十分之一,用于测试
    DataUtils(outpath=path_testset).outImg(test_X, test_y, int(test_img_nums / 10))

二、构建自己的数据集

构建方法为继承Dataset类,用DataLoader加载

1.引入库

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

2.构建MnistDataset类

# 构建自己的数据集
class MnistDataset(Dataset):
    def __init__(self, transform=None, lu_jing=None):
        self.lu_jing = lu_jing
        self.数据 = os.listdir(self.lu_jing)
        self.transform = transform
        self.len = len(self.数据)

    def __getitem__(self, index):
        image_index = self.数据[index]
        img_path = os.path.join(self.lu_jing, image_index)
        img = Image.open(img_path)
        if self.transform:
            img = self.transform(img)

        label = int(image_index[-5])
        label = self.oneHot(label)
        return img, label

    def __len__(self):
        return self.len

    # 将标签转为onehot编码
    def oneHot(self, label):
        tem = np.zeros(10)
        tem[label] = 1
        return torch.from_numpy(tem)

3.搭建网络模型

只为演示,模型比较简单。

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.Conv1 = torch.nn.Conv2d(1, 10, kernel_size=(5, 5))
        self.Conv2 = torch.nn.Conv2d(10, 20, kernel_size=(5, 5))
        self.pool = torch.nn.MaxPool2d(2)
        self.fl = torch.nn.Linear(320, 10)

    def forward(self, x):
        bs = x.size(0)
        x = F.relu(self.pool(self.Conv1(x)))
        x = F.relu(self.pool(self.Conv2(x)))
        x = x.view(bs, -1)
        x = self.fl(x)
        return x

三 完整代码

import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
from torchvision import transforms
import torch.nn.functional as F


# 构建自己的数据集
class MnistDataset(Dataset):
    def __init__(self, transform=None, lu_jing=None):
        self.lu_jing = lu_jing
        self.数据 = os.listdir(self.lu_jing)
        self.transform = transform
        self.len = len(self.数据)

    def __getitem__(self, index):
        image_index = self.数据[index]
        img_path = os.path.join(self.lu_jing, image_index)
        img = Image.open(img_path)
        if self.transform:
            img = self.transform(img)

        label = int(image_index[-5])
        label = self.oneHot(label)
        return img, label

    def __len__(self):
        return self.len

    # 将标签转为onehot编码
    def oneHot(self, label):
        tem = np.zeros(10)
        tem[label] = 1
        return torch.from_numpy(tem)


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.Conv1 = torch.nn.Conv2d(1, 10, kernel_size=(5, 5))
        self.Conv2 = torch.nn.Conv2d(10, 20, kernel_size=(5, 5))
        self.pool = torch.nn.MaxPool2d(2)
        self.fl = torch.nn.Linear(320, 10)

    def forward(self, x):
        bs = x.size(0)
        x = F.relu(self.pool(self.Conv1(x)))
        x = F.relu(self.pool(self.Conv2(x)))
        x = x.view(bs, -1)
        x = self.fl(x)
        return x


if __name__ == '__main__':
    # 训练集路径
    train_data = "C:\\Users\\60058670\\Desktop\\MNIST\\train"
    transform = transforms.Compose([transforms.ToTensor()])  # 归一化处理
    data = MnistDataset(transform=transform, lu_jing=train_data)
    data_loader = DataLoader(data, batch_size=200, shuffle=True)  # 使用DataLoader加载数据
    model = Model()
    criterion = torch.nn.CrossEntropyLoss()  # 交叉熵损失
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)  # model.parameters()自动完成参数的初始化操作

    for epoch in range(20):
        for i, data1 in enumerate(data_loader, 0):  # train_loader 是先shuffle后mini_batch
            inputs, labels = data1
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
        if epoch % 5 == 0:
            print(epoch, loss.item())
    # 测试集路径
    test_data = 'C:\\Users\\60058670\\Desktop\\MNIST\\test'

    x_test = MnistDataset(transform=transform, lu_jing=test_data)
    x_test = DataLoader(x_test, batch_size=100, shuffle=False)  # 使用DataLoader加载数据
    total = 0
    correct = 0
    for i, data in enumerate(x_test, 0):  # train_loader 是先shuffle后mini_batch
        inputs, labels = data
        y_pred = model(inputs)
        _, labels = torch.max(labels.data, dim=1)
        _, predicted = torch.max(y_pred.data, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        print('accuracy on test set: {} % '.format(100 * correct / total))
    print(correct, total)

总结

纸上得来终觉浅,绝知此事要躬行。自己动手写了代码就会发现一堆问题,知识就是在解决问题的过程中积累的。初学不久,有问题大家可以一起交流讨论。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch加载自己的数据集(以图片格式的Mnist数据集为例) 的相关文章

随机推荐

  • Vue2项目中使用高德地图自定义(Marker)标记点和创建(MassMarks)海量标记点

    前言 本篇文章就是单独分享一下在Vue2项目中如何自定义创建marker标记点和针对要创建庞大数量标记点时所采用的API 能够快速创建数量庞大的marker 不至于在浏览器渲染时产生卡顿的现象 需要了解如何在Vue2项目中引入高德地图请参照
  • [c++/java]递归系列

    本系列是根据个人的做题总结出来 或许有不对之处 望给位大佬指出 同时这个系列也是长期的一个系列 每遇到一个优秀的递归题目 我都会添加上去 基本递归思路 递归之结束条件 个人认为在写递归之前应该首先考虑什么时候递归结束 或者是递归收敛于什么条
  • 有一台电脑怎么挣钱_游戏搬砖就能躺着挣钱?“我就一台电脑,要求不高月入过万就好”...

    随着网络游戏的日益普及 游戏搬砖 已经和代练 陪玩一样 成为很多玩家专职 兼职的工作 对于这样的一份职业 有人不屑 有人眼浅 更多人在看了有搬砖党换一套海景房的新闻 觉得这个职业就是个可以一夜暴富或者躺着挣钱的职业 因而也想自己下海尝试 事
  • 基于stm32作品设计:多功能氛围灯、手机APP无线控制ws2812,MCU无线升级程序

    文章目录 一 作品背景 二 功能设计与实现过程 三 实现基础功能 一 首先是要选材 二 原理图设计 二 第一版本PCB设计 三 焊接PCB板 四 编写单片机程序 五 下载程序验证 四 外壳设计 一 CAD图纸设计 二 磨砂亚克力板 五 重新
  • [Git] 代码管理之 Git(六)Git rebase 压缩提交历史

    我们在工作中 可能会出现这样的情况 一项工作由好几个同事同时完成 然后每个人针对当前的feature都有对应的提交 那么就会造成同一个feature有多次提交的这样的冗余存在 除此之外 如果我们自己针对同一个feature的每天的提交以及一
  • JVM(六)方法调用(补充知识)

    方法调用并不等同于方法中的代码被执行 方法调用阶段唯一的任务就是确定被调用方法的版本 即调用哪一个方法 暂时还未涉及方法内部的具体运行过程 一切方法调用在Class文件里面存储的都只是符号引用 而不是方法在实际运行时内存布局中的入口地址 也
  • 关于H264相关的EBSP,RBSP,SODP的说明

    1 关于H264相关的EBSP RBSP SODP的说明 1 EBSP 扩展字节序列载荷 Encapsulated Byte Sequence Payload 它去掉了00 00 01 00 00 00 01这些起始码 但包含了0x3防止竞
  • 时序约束优先级_静态时序分析圣经翻译计划——附录A:SDC

    本附录将介绍1 7版本的SDC格式 此格式主要用于指定设计的时序约束 它不包含任何特定工具的命令 例如链接 link 和编译 compile 它是一个文本文件 可以手写或由程序创建 并由程序读取 某些SDC命令仅适用于实现 implemen
  • @Around简单使用示例——SpringAOP增强处理

    Around的作用 既可以在目标方法之前织入增强动作 也可以在执行目标方法之后织入增强动作 可以决定目标方法在什么时候执行 如何执行 甚至可以完全阻止目标目标方法的执行 可以改变执行目标方法的参数值 也可以改变执行目标方法之后的返回值 当需
  • 如何搭建Spring开发环境呢?

    转自 如何搭建Spring开发环境呢 下文讲述搭建Spring开发环境的方法分享 如下所示 由于Spring是基于Java代码的一个框架 所以在Spring环境搭建之前 我们需为开发环境安装好 JDK Java开发环境 Eclipse Ja
  • 来自ebay内部的「软件测试」学习资料,覆盖GUI、API自动化、代码级测试及性能测试等,Python等,拿走不谢!...

    在软件测试领域从业蛮久了 常有人会问我 刚入测试一年 很迷茫 觉得没啥好做的 测试在公司真的不受重视 我是不是去转型做开发会更好 资深的测试架构师的发展路径是怎么样的 我平时该怎么学习 我估计不少人有这样的想法 甚至你也会被身边的人所影响
  • React 中constructor 作用

    React 中constructor 作用 react中的constructor大体有两个作用 1 初始化this state 2 纠正方法的this的指向 constructor props super props this state
  • 大数据毕设 python+深度学习+opencv实现植物识别算法系统

    文章目录 0 前言 2 相关技术 2 1 VGG Net模型 2 2 VGG Net在植物识别的优势 1 卷积核 池化核大小固定 2 特征提取更全面 3 网络训练误差收敛速度较快 3 VGG Net的搭建 3 1 Tornado简介 1 优
  • LWIP学习笔记(2)---IP协议实现细节

    IP头 收到的数据首先保存在pbuf结构中 The IPv4 header struct ip hdr version header length PACK STRUCT FLD 8 u8 t v hl type of service PA
  • C++11线程库 (六) 条件变量 Condition variables

    一 什么是条件变量 条件变量类 condition variable 是一个同步原语 它可以在同一时间阻塞一个线程或者多个线程 直到其他线程改变了共享变量 条件 并通知 primitive 原语 表达的是基础 基本的 是其他复杂应用的构建基
  • 基于高德地图的描点操作,监听地图缩放,展示合理数量的marker

    1 根据两点经纬度算两点之间的距离函数 function Rad d return d Math PI 180 0 经纬度转换成三角函数中度分表形式 计算距离 参数分别为第一点的纬度 经度 第二点的纬度 经度 function GetDis
  • java缓存面试问题_分布式缓存的面试题5

    1 面试题 如何保证Redis的高并发和高可用 redis的主从复制原理能介绍一下么 redis的哨兵原理能介绍一下么 2 面试官心里分析 其实问这个问题 主要是考考你 redis单机能承载多高并发 如果单机扛不住如何扩容抗更多的并发 re
  • R语言数据输入

    一 使用键盘输入数据 在导入数据比较少的时候 我们使用这种方法 R中的函数 edit 会自动调用一个允许手动输入数据的文本编辑器 具体步骤如下 1 创建一个空数据框 或矩阵 其中变量名和变量的模式需与理想中的最终数据集一致 2 针对这个数据
  • java使用aspose将word,excel,ppt转pdf

    1 测试环境springboot jdk1 8 aspose cells 8 5 2 jar 用于转换xls aspose words 16 8 0 jdk16 jar 用于转换doc 2 所用jar 签名百度网盘地址 链接 https p
  • Pytorch加载自己的数据集(以图片格式的Mnist数据集为例)

    文章目录 Pytorch加载自己的数据集 以图片格式的Mnist数据集为例 前言 一 数据集转换 二 构建自己的数据集 1 引入库 2 构建MnistDataset类 3 搭建网络模型 三 完整代码 总结 Pytorch加载自己的数据集 以