【PyTorch学习】(三)自定义Datasets

2023-11-18

torchvision.datasets源码地址:https://github.com/pytorch/vision/blob/master/torchvision/datasets


前两篇从搭建经典的ResNet,DenseNet入手简单的了解了下PyTorch搭建网络的方式,但训练一个模型光光搭建好一个网络是不够的,正所谓巧妇难为无米之炊,如何将数据处理成网络可以传递的Tensor也尤为重要,而数据准备过程最最最最最重要的就是DatasetsDataloader两部分!

torchvision.datasets.ImageFolder就是官方给出的一个datasets的事例,具体使用直接贴上官方tutorial上的代码供参考:

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

但由于torchvision.datasets.ImageFolder函数的使用必须对数据的放置有要求,必须在data_dir目录下放置train和val两个文件夹,然后每个文件夹下,每一类图片单独放在一个文件夹里。官方的例子是ants和bees,所以在train和val文件夹下都有ants和bees这两个文件夹,分别放置相应的文件。

那么问题就来了,我们通常打完标签,是不会根据标签进行分类,而且在进行目标检测时一张图可能对应有多个标签,而是通过一个xml文件或者json文件用于记录label信息,所以是不满足ImageFolder的要求的。

所以根据实际数据情况,自定义Datasets就很关键,接下来我们就根据ImageFolder的函数形式,顺藤摸瓜从头来看如何自定义一个Datasets!


一、torch.utils.data.Dataset

首先,可以看到ImageFolder类继承了DatasetFolder类,DatasetFolder类又继承了torch一个基础的抽象类torch.utils.data.Dataset类。

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

自定义Datasets的关键就是重载 "__len__"和"__getitem__"两个函数!而 "__add__"函数的作用是使得类定义对象拥有"object1 + object2"的功能,一般情况不需要重载该函数。

  1. __len__函数:使得类对象拥有 "len(object)"功能,返回dataset的size。

  2. __getitem__函数:使得类对象拥有"object[index]"功能,可以用索引i去获得第i+1个样本。

二、torchvision.datasets.CocoDetection

再来看看同样继承于torch.utils.data.Dataset的CocoDetection dataset是如何定义上述两个函数的!

1.__init__:

def __init__(self, root, annFile, transform=None, target_transform=None):
    # 从cocoapi导入pycocotools下的COCO类
    from pycocotools.coco import COCO
    self.root = root
    # 初始化一个COCO对象
    self.coco = COCO(annFile)
    # 将每张图unique的id属性转化为list存储在self.ids中
    self.ids = list(self.coco.imgs.keys())
    self.transform = transform
    self.target_transform = target_transform

(1)初始化函数可以接受四个参数:

  • root: COCO形式的数据集的根目录地址。
  • annFile: COCO形式的数据集中.json文件的目录地址。
  • transform: 原始图像是否需要进行变换(数据增强,默认是None不做增强)。
  • target_transform: 标签是否需要进行变换(标签变换需要和原始图像变换相对应,默认是None不做增强)。

(2)初始化COCO对象时,将.json文件解析为字典形式导入内存,并创建调用createIndex()创建索引

(3)self.coco.imgs是以每张图unique的id作为key,json文件images下每一image信息作为value的一个字典。

2.__len__:

def __len__(self):
    # 因为图片的id是unique的,所以self.ids的长度就等于总图片数
    return len(self.ids)

3.__getitem__:

def __getitem__(self, index):
    """
    Args:
        index (int): Index

    Returns:
        tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
    """
    coco = self.coco
    # 通过索引获得图片的id
    img_id = self.ids[index]
    # 再通过getAnnIds方法利用img_id找到对应的anno_id
    ann_ids = coco.getAnnIds(imgIds=img_id)
    # 根据anno_id和标签之间的映射关系,解析出标签target
    target = coco.loadAnns(ann_ids)
       
    path = coco.loadImgs(img_id)[0]['file_name']
    # 根据每张图的file_name结合之前传入的图片放置的根目录读取图片信息
    img = Image.open(os.path.join(self.root, path)).convert('RGB')
    # 判断是否需要进行数据增强
    if self.transform is not None:
        img = self.transform(img)
    # 判断标签是否需要进行变换
    if self.target_transform is not None:
        target = self.target_transform(target)
        
    # 最终返回值形式可以根据自己需要进行设计。此处为一个tuple,包含一张图片以及对应的标签。
    return img, target

三、自定义人脸关键点dataset

以下这个例子就是自定义的FaceLandmarksDataset,效果是从.csv文件中读取每张图上的68个人脸面部关键点的坐标x,y,然后根据.csv文件中对应的图片名,读取相应的图片,然后返回值是一个sample字典,包含'image'和'landmarks'两个key。

class FaceLandmarksDataset(Dataset):
    
    def __init__(self, root_dir, csv_file, transform=None):
        self.root_dir = root_dir
        self.landmarks_frame = pd.read_csv(csv_file)
        self.transform = transform
        
    def __len__(self):
        return len(self.landmarks_frame)
    
    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

数据准备阶段datasets部分就简单介绍完了,下篇继续介绍另一个关键部分dataloader!

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【PyTorch学习】(三)自定义Datasets 的相关文章

随机推荐

  • mpvue实现微信小程序样式切换以及本地缓存

    功能描述 在页面A的添加应用中点击 添加 跳转到展示所有应用的页面B 通过点击开关 在页面A中展示所开启的应用 效果展示 代码 页面B代码 div class itembox div class boxhead img div class
  • [HDU 5079][2014 Asia AnShan Regional Contest]Square(DP套DP)

    题目链接 http acm hdu edu cn showproblem php pid 5079 题目大意 给你一个 n n n 8 n middot n n le 8 的棋盘 上面有一些格子必须是黑色 其它可以染黑或者染白 对于一个棋盘
  • python实现逻辑回归三种方法_纯Python实现逻辑回归

    前几天使用后sklearn实现了逻辑回归 这里用纯python实现逻辑回归 首先 我们定义一个sigmoid函数 def sigmoid inX sigmoid函数 return 1 0 1 exp inX 这里使用梯度上升进行逻辑回归 梯
  • 【编译之美】【5. 代码优化:数据流分析】

    有些优化只能在全局优化中做 在本地优化中做不了 比如 代码移动 Code motion 能够将代码从一个基本块挪到另一个基本块 比如从循环内部挪到循环外部 来减少不必要的计算 循环剥离 部分冗余删除 Partial Redundancy E
  • 角落的开发工具集之Vs(Visual Studio)2017插件推荐

    工具善其事 必先利其器 装好这些插件让vs更上一层楼 因为最近录制视频的缘故 很多朋友都在QQ群留言 或者微信公众号私信我 问我一些工具和一些插件啊 怎么使用的啊 那么今天我忙里偷闲整理一下清单 然后在这里面公布出来 Visual Stud
  • 毕业设计-基于深度学习的花卉识别分类

    目录 前言 课题背景和意义 实现技术思路 一 花卉识别相关理论基础 二 基于 ResNeXt 和迁移学习的花卉种类识别 三 基于 EfficientNet 和迁移学习的花卉种类识别 实现效果图样例 最后 前言 大四是整个大学期间最忙碌的时光
  • scss 中公共变量的导出方法:export

    前言 在使用vue或者react开发项目时都会用到scss 或者less等的扩展语言 那么肯定会有公共变量提取与使用 这篇文章就是记录如何导出公共css变量的 export 关键词 menuText bfcbd9 menuActiveTex
  • React 相关方法(API)介绍-元素与组件操作

    JSX可以减少定义组件的复杂性 但对于React来说JSX并不是必须的 JSX标签最终会被转换为原生的JavaScript 除使用JSX语法外 还可以使用React提供的API来创建组件 本文将介绍使用React创建元素 及一些React中
  • 类与对象基础

    1 面向对象概述 面向过程就是分析出解决问题所需要的步骤 然后用函数把这些步骤一一实现 使用的时候依次调用就可以了 面向对象则是把构成问题的事务按照一定规则划分为多个独立的对象 然后通过调用对象的方法来解决问题 当然 一个应用程序会包含多个
  • JAVA构造方法与static 关键字

    JAVA的构造方法 什么是构造方法 构造方法用来生成一个实例化的对象并对对象实例中的成员变量进行初始化 采用new创建对象时 构造方法被执行 构造方法的方法名必须和类名保持一致 注意 构造方法没有返回值 不可以加void 只能用 publi
  • 设计模式之命令模式

    优质资源分享 学习路线指引 点击解锁 知识定位 人群定位 Python实战微信订餐小程序 进阶级 本课程是python flask 微信小程序的完美结合 从项目搭建到腾讯云部署上线 打造一个全栈订餐系统 Python量化交易实战 入门级 手
  • 【2021应用上架】超详细开发者账号申请&应用上架审核经验整理

    一 准备阶段需要注意的 1 上架前开发者账号申请 申请的主体确定 在公司有多个主体的情况下 用哪个公司主体认证开发者 上架APP时需要考虑到应用相关的各种材料申请在哪个公司名下 材料所属公司主体与开发者账号主体不一致的情况需要开发者花费时间
  • vue节流和防抖

    节流 节流是间隔执行 在定时器到时间后再清空定时器 函数将每个 n 秒执行一次 在内部定义一个定时器和一个开关变量 初始化变量为true 执行定期器前判断变量是否false 就return 为true 如果是继续执行 并且把变量赋值为fal
  • 使用JSON.toJSONString时,出现“$ref”怎么办?服务器返回对象显示$ref怎么解决?

    现象 代码 Map
  • nvm 和 nrm安装使用

    前端工具推荐 nvm Node 版本管理工具 和 nrm 管理npm源 一 nvm 如果直接将 node 安装到电脑上 通常只能安装某个特定的版本 如 v18 12 1 而某些老项目可能只支持老版本的 node 如 v14 19 3 这时候
  • UNIX网络编程卷一 学习笔记 第三十章 客户/服务器程序设计范式

    开发一个Unix服务器程序时 我们本书做过的进程控制 1 迭代服务器 iterative server 它的适用情形极为有限 因为这样的服务器在完成对当前客户的服务前无法处理已等待服务的新客户 2 并发服务器 concurrent serv
  • 解决win10升级到win11,打不开安全中心的问题(亲测有效,已修复)

    相信很多人也碰上过这种问题 升级到了win 11 但是安全中心打不开了 报错 需要使用新应用以打开此windowsdefender链接 但是微软的应用商店并没有这个软件 然后我实验了一种方法 1 去微软的应用商店 Microsoft Sto
  • mysql中如何操作varchar类型的日期进行比较、排序等操作

    在mysql使用过程中 日期一般都是以datetime timestamp等格式进行存储的 但有时会因为特殊的需求或历史原因 日期的存储格式是varchar 那么我们该如何处理这个varchar格式的日期数据呢 使用函数 STR TO DA
  • SSM框架基于JSP犬舍寄养系统

    项目介绍 SSM框架基于JSP犬舍寄养系统的设计与实现 高清视频演示 SSM框架基于JSP犬舍寄养系统的设计与实现 安装视频演示 SSM框架基于JSP犬舍寄养系统的设计与实现 系统说明 1 前台功能模块 首先注册会员 登录进平台 然后选择自
  • 【PyTorch学习】(三)自定义Datasets

    torchvision datasets源码地址 https github com pytorch vision blob master torchvision datasets 前两篇从搭建经典的ResNet DenseNet入手简单的了