Reid训练代码之数据集处理

2023-10-26

本篇文章是对yolov5_reid这篇文章训练部分的详解。

该项目目录为:

.
|-- config # reid输入大小,数据集名称,损失函数等配置
|-- configs # 训练时期超参数定义
|-- data # 存储数据集和数据处理等代码,以及yolov5类别名称等
|-- engine # 训练和测试mAP,rank等相关代码
|-- layers # loss定义
|-- logs # 训练好的权重将存储在这
|-- modeling # 定义的网络
|-- output  # 输出
|-- person_search  # 人员查找
|-- readme.md # readme
|-- solver # 优化器相关代码
|-- tests
|-- tools  # 训练和测试代码
|-- utils  # logger等相关代码
`-- weights  # 存放预权重

数据集加载:

数据集加载与处理,需要调用头文件:

from data import make_data_loader

 make_data_loader

传入参数为cfg,训练中的相关配置文件。

build_transforms函数

这个函数传入函数有两个,cfg是配置文件,is_train=True表示训练。normalize_transform是计算数据集的均值和方差。均值为[0.485, 0.456, 0.406],方差为[0.229, 0.224, 0.225](可以看配置文件)。

如果is_train=True的时候,对数据集进行处理。

T.Resize:将图像调整为[256,128]大小;

T.RandomHorizontalFlip(p=cfg.INPUT.PROB):随机水平翻转,设置为0.5;

T.Pad:padding值,10;

T.ToTensor():转为tensor;

normalize_transform:图像的均值和方差;

RandomErasing:数据增强(随机擦除),将图片内的某块区域填充相同的像素值,从而将该区域的图片信息遮盖,强迫模型学习该区域外的特征进行识别,在一定程度上避免模型陷入局部最优,从而提高模型的泛化能力。

将多个变换组合在一起。

如果测试的时候,is_train=False,不用数据增强。

def build_transforms(cfg, is_train=True):
    normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
    if is_train:
        transform = T.Compose([
            T.Resize(cfg.INPUT.SIZE_TRAIN),
            T.RandomHorizontalFlip(p=cfg.INPUT.PROB),
            T.Pad(cfg.INPUT.PADDING),
            T.RandomCrop(cfg.INPUT.SIZE_TRAIN),
            T.ToTensor(),
            normalize_transform,
            RandomErasing(probability=cfg.INPUT.RE_PROB, mean=cfg.INPUT.PIXEL_MEAN)
        ])
    else:
        transform = T.Compose([
            T.Resize(cfg.INPUT.SIZE_TEST),
            T.ToTensor(),
            normalize_transform
        ])

    return transform

 继续返回make_data_loader函数。

通过build_transforms仅仅返回的是训练和测试需要用的一些数据处理方面的"规则"。

train_transforms = build_transforms(cfg, is_train=True)
val_transforms = build_transforms(cfg, is_train=False)

 num_workers:获取进程数量,我这里是4.

num_workers = cfg.DATALOADER.NUM_WORKERS

init_dataset函数 

传入参数name:数据集的名称,我这里是mark1501;

还传入了数据集的路径:我这里是./data

该函数主要是判断支持的数据集格式。

def init_dataset(name, *args, **kwargs):
    if name not in __factory.keys():
        raise KeyError("Unknown datasets: {}".format(name))
    return __factory[name](*args, **kwargs)

继续看make_data_loader函数。

训练时分类的数量,这里是751。注意!在训练的时候是751,在测试的是1501.

num_classes = dataset.num_train_pids

 ImageDataset函数

该类基础Dataset,因此说明该类是做数据集处理的。上面我们说到的build_transforms仅仅是一些数据集处理的"规则"。

在调用该类的时候,传入两个参数,一个是dataset.train[训练数据集的图片路径],另一个就是train_transforms[处理的规则]。所以这个类就知道了,是用上面定义的“规则”来处理我们的数据集。

在下面这段代码中self.dataset[index]就是对数据集遍历(__getitem__就是迭代器),加入此时index=0,此时获得为:('./data\\Market1501\\bounding_box_train\\0002_c1s1_000451_03.jpg', 0, 0)。img_path就为数据集的路径,pid为类,camid为相机id[这个需要了解markt1501数据集]。

read_image函数就是通过PIL读取的图像。然后用transform处理。返回值有四个,img[数据增强后的图像],pid[类别],camid[相机id],img_path[图像路径]。

class ImageDataset(Dataset):
    """Image Person ReID Dataset"""

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img_path, pid, camid = self.dataset[index]
        img = read_image(img_path)

        if self.transform is not None:
            img = self.transform(img)

        return img, pid, camid, img_path

下面的两个图就是增强后的效果 

 

 


接下来再回到make_data_loader函数。

下面一段代码是对处理后的数据集进行加载,这里调用的torch中DataLoader函数。传入的参数有batch,我这里是8,shuffle表示打乱,collate_fn这个很重要,就是把这些按batch处理。

    if cfg.DATALOADER.SAMPLER == 'softmax':
        train_loader = DataLoader(
            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers,
            collate_fn=train_collate_fn
        )
    else:
        train_loader = DataLoader(
            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH,
            sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE),
            num_workers=num_workers, collate_fn=train_collate_fn
        )

同理,验证集也是这样处理。

最终返回训练集,验证集,数据集长度(数量),类别:751

完整代码

def make_data_loader(cfg):
    train_transforms = build_transforms(cfg, is_train=True)
    val_transforms = build_transforms(cfg, is_train=False)
    num_workers = cfg.DATALOADER.NUM_WORKERS
    if len(cfg.DATASETS.NAMES) == 1:
        dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR)
    else:
        # TODO: add multi dataset to train
        dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR)

    num_classes = dataset.num_train_pids
    train_set = ImageDataset(dataset.train, train_transforms)
    if cfg.DATALOADER.SAMPLER == 'softmax':
        train_loader = DataLoader(
            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers,
            collate_fn=train_collate_fn
        )
    else:
        train_loader = DataLoader(
            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH,
            sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE),
            num_workers=num_workers, collate_fn=train_collate_fn
        )

    val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms)
    val_loader = DataLoader(
        val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
        collate_fn=val_collate_fn
    )
    return train_loader, val_loader, len(dataset.query), num_classes

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Reid训练代码之数据集处理 的相关文章

随机推荐

  • 多线程案例(2) - 阻塞队列

    目录 一 阻塞队列 1 1 什么是阻塞队列 1 2 生产者消费者模型 1 3 标准库中的阻塞队列 1 4 阻塞队列的实现 一 阻塞队列 1 1 什么是阻塞队列 阻塞队列 BlockingQueue 是一种特殊的队列 遵循 先进先出 的原则
  • Deep-Learning-YOLOV4实践:ScaledYOLOv4模型训练自己的数据集调试问题总结

    error error1 CUDA out of memory error2 TypeError can t convert cuda error Deep Learning YOLOV4实践 ScaledYOLOv4 数据集制作 Deep
  • 知识库-kafka shell脚本用法

    脚本名称 用途描述 connect distributed sh 连接kafka集群模式 connect standalone sh 连接kafka单机模式 kafka acls sh todo kafka broker api versi
  • 一篇搞定dockerfile定制镜像过程

    一 定制镜像的两种方法 1 docker commit 通过已有容器创建镜像 提交容器快照作为镜像 不推荐 2 docker build 就是本文着重讲的dockerfile创建镜像方式 推荐 docker commit无法还原镜像制作过程
  • 【Linux学习】epoll详解

    什么是epoll epoll是什么 按照man手册的说法 是为处理大批量句柄而作了改进的poll 当然 这不是2 6内核才有的 它是在2 5 44内核中被引进的 epoll 4 is a new API introduced in Linu
  • centos7运行vue项目问题汇总

    一 node踩坑之This is probably not a problem with npm There is likely additional logging output above 错误 解决步骤 1 可能由于种种版本更新的原因
  • windbg 常用命令详解

    一 1 address eax 查看对应内存页的属性 2 vertarget 显示当前进程的大致信息 3 peb 显示process Environment Block 4 lmvm 可以查看任意一个dll的详细信息 例如 我们查看cyus
  • java中List按照指定字段排序工具类

    文章标题 java中List按照指定字段排序工具类 文章地址 http blog csdn net 5iasp article details 17717179 包括如下几个类 1 实体类 package com newyear wish
  • 【C语言】螺旋数组

    螺旋数组的打印 程序C语言代码 更改宏定义的数值即可实现螺旋数组行列的变化 include stdio h define ROW 5 宏定义行 define COL 5 宏定义列 void main int arr ROW COL 0 in
  • Python Decorators(二):Decorator参数

    Python Decorators II Decorator Arguments October 19 2008 本文是Python 3 Patterns Idioms Python3之模式和用法 一书的章节节选第二部分 点击这里阅读第一部
  • Kotlin数据类型(一:数据类型)

    一 Boolean Boolean类型有两种类型的 true flase val a Boolean true val b Boolean false 二 Number数据类型 package net println kotlin auth
  • 强化学习 DQN 速成

    强化学习 DQN 速成 这是对 深度强化学习 王树森 张志华 中 DQN 部分的缩写以及部分内容的个人解读 书中的 DQN 是一个相对终极版本的存在 相信体量会比网络上其他资料要大很多 基本概念 我们通过贪吃蛇来引入几个基本概念 符号 中文
  • Flink Windows(窗口)详解

    Windows 窗口 Windows是流计算的核心 Windows将流分成有限大小的 buckets 我们可以在其上应用聚合计算 ProcessWindowFunction ReduceFunction AggregateFunction或
  • MySQL redo log和undo log

    Redo Log REDO LOG称为重做日志 当MySQL服务器意外崩溃或者宕机后 保证已经提交的事务持久化到磁盘中 持久性 InnoDB是以页为单位去操作记录的 增删改查都会加载整个页到buffer pool中 磁盘 gt 内存 事务中
  • Matlab矩阵处理

    一 通用的特殊矩阵 zero m zeros m n zero size A 产生全为零的矩阵 格式下同 ones 产生全为一的矩度阵 eye 产生单位矩阵 rand 产生在 0 1 区间均匀分布的矩阵 randn 产生均值为0 方差为1的
  • C计数问题---2023河南萌新联赛第(三)场:郑州大学

    解析 n 可以分成两个数 记录每个数的因子对数 乘起来即可 注意当因子相同时 只 1 include
  • Java文件类型校验之Apache Tika

    一 背景 判断文件类型一般可采用两种方式 1 后缀名判断 简单易操作 但无法准确判断类型 2 文件头信息判断 通常可以判断文件类型 但有些文件类型无法判断 如word和excel头信息的前几个字节是一样的 无法判断 使用apache tik
  • flink watermark 生成机制与总结

    flink watermark 生成机制与总结 watermark 介绍 watermark生成方式 watermark 的生成值算法策略 watermark策略设置代码 watermark源码分析 watermark源码调用流程debug
  • 你知道几种延迟队列的实现方案?

    在开发中 往往会遇到一些关于延时任务的需求 例如 生成订单30分钟未支付 则自动取消 生成订单60秒后 给用户发短信 对上述的任务 我们给一个专业的名字来形容 那就是延时任务 那么这里就会产生一个问题 这个延时任务和定时任务的区别究竟在哪里
  • Reid训练代码之数据集处理

    本篇文章是对yolov5 reid这篇文章训练部分的详解 该项目目录为 config reid输入大小 数据集名称 损失函数等配置 configs 训练时期超参数定义 data 存储数据集和数据处理等代码 以及yolov5类别名称等 eng