Pytorch Torch.utils.data.Sampler

2023-11-20

对于 可迭代样式的数据集,数据加载顺序完全由用户定义的可迭代样式控制。这样可以更轻松地实现块读取和动态批处理大小(例如,通过每次生成一个批处理的样本)。
为了从数据集中读取数据,pytorch提供了Sampler基类与多个子类实现不同方式的数据采样

  • Sequential Sampler(顺序采样)
  • Random Sampler(随机采样)
  • Subset Random Sampler(子集随机采样)
  • Weighted Random Sampler**(加权随机采样)**等等。
1、基类Sampler
class Sampler(object):
    r"""Base class for all Samplers.
    """
    def __init__(self, data_source):
        pass
    def __iter__(self):
        raise NotImplementedError

对于所有的采样器来说,都需要继承Sampler类,必须实现的方法为__iter__(),也就是定义迭代器行为,返回可迭代对象。除此之外,Sampler类并没有定义任何其它的方法。

2、顺序采样Sequential Sampler
class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.
    Arguments:
        data_source (Dataset): dataset to sample from
    """
    def __init__(self, data_source):
        self.data_source = data_source
    def __iter__(self):
        return iter(range(len(self.data_source)))
    def __len__(self):
        return len(self.data_source)

顺序采样类并没有定义过多的方法,其中初始化方法仅仅需要一个Dataset类对象作为参数。对于__len__()只负责返回数据源包含的数据个数;iter()方法负责返回一个可迭代对象,这个可迭代对象是由range产生的顺序数值序列,也就是说迭代是按照顺序进行的。前述几种方法都只需要self.data_source实现了__len__()方法,因为这几种方法都仅仅使用了len(self.data_source)函数。所以下面采用同样实现了__len__()的list类型来代替Dataset类型做测试:

# 定义数据和对应的采样器
data = list([17, 22, 3, 41, 8])
seq_sampler = sampler.SequentialSampler(data_source=data)
# 迭代获取采样器生成的索引
for index in seq_sampler:
    print("index: {}, data: {}".format(str(index), str(data[index])))

得到下面的输出,说明Sequential Sampler产生的索引是顺序索引

index: 0, data: 17
index: 1, data: 22
index: 2, data: 3
index: 3, data: 41
index: 4, data: 8
3、随机采样RandomSampler
class RandomSampler(Sampler):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify :attr:`num_samples` to draw.
    Arguments:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
            is supposed to be specified only when `replacement` is ``True``.
    """
    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        # 这个参数控制的应该为是否重复采样
        self.replacement = replacement
        self._num_samples = num_samples
    # 省略类型检查
    @property
    def num_samples(self):
        # dataset size might change at runtime
        # 初始化时不传入num_samples的时候使用数据源的长度
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples
    # 返回数据集长度
    def __len__(self):
        return self.num_samples

最重要的是__iter__()方法,定义了核心的索引生成行为:

def __iter__(self):
    n = len(self.data_source)
    if self.replacement:
        # 生成的随机数是可能重复的
        return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
    # 生成的随机数是不重复的
    return iter(torch.randperm(n).tolist())

其中if判断处返回了两种随机值,根据是否在初始化中给出replacement参数决定是否重复采样,区别核心在于**randint()函数生成的随机数学列是可能包含重复数值的,而randperm()函数生成的随机数序列是绝对不包含重复数值的,**下面分别测试是否使用replacement作为输入的情况,首先是不使用时:

ran_sampler = sampler.RandomSampler(data_source=data)
# 得到下面输出
ran_sampler = sampler.RandomSampler(data_source=data)
index: 0, data: 17
index: 2, data: 3
index: 3, data: 41
index: 4, data: 8
index: 1, data: 22

可以看出生成的随机索引是不重复的,下面是采用replacement参数的情况:

ran_sampler = sampler.RandomSampler(data_source=data, replacement=True)
# 得到下面的输出
index: 0, data: 17
index: 4, data: 8
index: 3, data: 41
index: 4, data: 8
index: 2, data: 3

此时生成的随机索引是有重复的(4出现两次),也就是说第“4”条数据可能会被重复的采样。

4、子集随机采样Subset Random Sampler
class SubsetRandomSampler(Sampler):
    r"""Samples elements randomly from a given list of indices, without replacement.
    Arguments:
        indices (sequence): a sequence of indices
    """
    def __init__(self, indices):
        # 数据集的切片,比如划分训练集和测试集
        self.indices = indices
    def __iter__(self):
        # 以元组形式返回不重复打乱后的“数据”
        return (self.indices[i] for i in torch.randperm(len(self.indices)))
    def __len__(self):
        return len(self.indices)

上述代码中__len__()的作用与前面几个类的相同,依旧是返回数据集的长度,区别在于__iter__()返回的并不是随机数序列,而是通过随机数序列作为indices的索引,进而返回打乱的数据本身。需要注意的仍然是采样是不重复的,也是通过randperm()函数实现的。按照网上可以搜集到的资料,Subset Random Sampler应该用于训练集、测试集和验证集的划分,下面将data划分为train和val两个部分,再次指出__iter__()返回的的不是索引,而是索引对应的数据:

sub_sampler_train = sampler.SubsetRandomSampler(indices=data[0:2])
sub_sampler_val = sampler.SubsetRandomSampler(indices=data[2:])
# 下面是train输出
index: 17
index: 22
*************
# 下面是val输出
index: 8
index: 41
index: 3
5、加权随机采样WeightedRandomSampler
class WeightedRandomSampler(Sampler):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights)."""
    def __init__(self, weights, num_samples, replacement=True):
         # ...省略类型检查
         # weights用于确定生成索引的权重
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        # 用于控制是否对数据进行有放回采样
        self.replacement = replacement
    def __iter__(self):
        # 按照加权返回随机索引值
        return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())

对于Weighted Random Sampler类的__init__()来说,replacement参数依旧用于控制采样是否是有放回的;num_sampler用于控制生成的个数;weights参数对应的是“样本”的权重而不是“类别的权重”。其中__iter__()方法返回的数值为随机数序列,只不过生成的随机数序列是按照weights指定的权重确定的,测试代码如下:

# 位置[0]的权重为0,位置[1]的权重为10,其余位置权重均为1.1
weights = torch.Tensor([0, 10, 1.1, 1.1, 1.1, 1.1, 1.1])
wei_sampler = sampler.WeightedRandomSampler(weights=weights, num_samples=6, replacement=True)
# 下面是输出:
index: 1
index: 2
index: 3
index: 4
index: 1
index: 1

从输出可以看出,位置[1]由于权重较大,被采样的次数较多,位置[0]由于权重为0所以没有被采样到,其余位置权重低所以都仅仅被采样一次。

6、批采样BatchSampler
class BatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices."""
    def __init__(self, sampler, batch_size, drop_last):# ...省略类型检查
        # 定义使用何种采样器Sampler
        self.sampler = sampler
        self.batch_size = batch_size
        # 是否在采样个数小于batch_size时剔除本次采样
        self.drop_last = drop_last
    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            # 如果采样个数和batch_size相等则本次采样完成
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # for结束后在不需要剔除不足batch_size的采样个数时返回当前batch        
        if len(batch) > 0 and not self.drop_last:
            yield batch
    def __len__(self):
        # 在不进行剔除时,数据的长度就是采样器索引的长度
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size

在定义好各种采样器以后,需要进行“batch”的采样。BatchSampler类的__init__()函数中sampler参数对应前面介绍的XxxSampler类实例,也就是采样方式的定义;drop_last为“True”时,如果采样得到的数据个数小于batch_size则抛弃本个batch的数据。对于__init__()中的for循环,作用应该是以“生成器”的方式不断的从sampler中获取batch,比如sampler中包含1.5个batch_size的数据,那么在drop_last为False的情况下,for循环就会yield出个batch,下面的例子中batch sampler采用的采样器为顺序采样器:

seq_sampler = sampler.SequentialSampler(data_source=data)
batch_sampler = sampler.BatchSampler(seq_sampler, 3, False)
# 下面是输出
batch: [0, 1, 2]
batch: [3, 4]
总结:

对于深度神经网络来说,数据的读取方式对网络的优化过程是有很大影响的。比如常见的策略为“打乱”数据集,使得每个epoch中各个batch包含的数据是不同的。在训练网络时,应该从上述的几种采样策略中进行选择,从而确定最优方式。

参考:

【1】https://zhuanlan.zhihu.com/p/100280685

【2】https://blog.csdn.net/WangWen123_111/article/details/104492513

【3】https://blog.csdn.net/tsq292978891/article/details/79414512?utm_medium=distribute.pc_relevant.none-task-blog-baidujs_title-0&spm=1001.2101.3001.4242

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch Torch.utils.data.Sampler 的相关文章

随机推荐

  • 数学(五) -- LC[415]&[455] 字符串相加与两数相加

    1 字符串相加 1 1 题目描述 给定两个字符串形式的非负整数 num1 和num2 计算它们的和并同样以字符串形式返回 你不能使用任何內建的用于处理大整数的库 比如 BigInteger 也不能直接将输入的字符串转换为整数形式 示例 1
  • qtdesigner界面美化

    文章目录 前言 一 QSS 1 1 QSS选择器介绍 2 2 使用 二 自定义属性 2 1 添加自定义属性 前言 pyqt5能快速构建界面 但是你会发现构建出来的界面没有像我们平常用的客户端界面一样美观 现在 就让我学习一下如何美化界面 本
  • 如何防止过拟合和欠拟合

    过拟合和欠拟合是模型训练过程中经常出现的问题 两种情况正好相反 现将两者的定义及如何防止进行简要总结 1 过拟合 1 1 定义 是指模型对于训练数据拟合呈现过当的情况 反映到评估指标上就是模型在训练集上的表现很好 但是在测试集上的表现较差
  • 如何终止java线程

    终止线程的三种方法 有三种方法可以使终止线程 1 使用退出标志 使线程正常退出 也就是当run方法完成后线程终止 2 使用stop方法强行终止线程 这个方法不推荐使用 因为stop和suspend resume一样 也可能发生不可预料的结果
  • n行Python代码系列:两行代码去除抖音快手短视频尾部Logo

    老猿Python博文目录 https blog csdn net LaoYuanPython 一 引言 最近看到好几篇类似 n行Python代码 的博文 看起来还挺不错 简洁 实用 传播了知识 带来了阅读量 撩动了老猿的心 决定跟风一把 推
  • 基于Springboot+Vue图书商城系统

    目录 1 系统架构 2 系统介绍 3 运行环境 4 系统演示 5 前端搭建 6 功能展示 7 代码展示 1 系统架构 后台 SpringBoot Mybatis plus Mybatis Hutool工具包 lombok插件 前台 Vue
  • UE4 制作导出Content目录下某个文件夹内所有模型的六视图并将模型资源文件复制到指定文件夹的插件

    一 新建空白插件 在Bulid cs内加入两个模块 EditorSubsystem UnrealEd PublicDependencyModuleNames AddRange new string Core EditorSubsystem
  • 用vue写一个学校官网

    首先 你需要准备一个文本编辑器 如 Sublime Text VSCode 等 和浏览器 然后 你需要安装 Vue 你可以通过使用 npm 命令来安装 npminstall g vue
  • Ubuntu Server 22.04 安装桌面

    Ubuntu Server 22 04 安装桌面 sudo apt get update sudo apt get upgrade apt get install y ubuntu desktop 如果你不想安装一些附加的程序 可用以下命令
  • Tomcat startup.bat Using CATALINA_OPTS: ““,启动闪退

    Tomcat 控制台打开startup bat 发现Tomcat终端窗口闪退打不开 在startup bat最后加一个pause 会弹出Using CATALINA OPTS 重新设置了JAVA HOME 没用 在startup bat最前
  • python对csv文档进行读,写,追加操作(csv,pandas)

    python处理csv文档的两种方法 csv pandas python处理csv一般采用两种方法一种是import pandas 另一种是 import csv 本文将介绍这两种方法对csv进行读 写 追加的操作 1 import pan
  • mysql unique key使用_mysql unique key在查询中的使用与相关问题

    1 建表语句 CREATE TABLE employees emp no int 11 NOT NULL birth date date NOT NULL first name varchar 14 NOT NULL last name v
  • WordPress多站点伪静态的设置方式(WP站群必备)

    apache的设置方式 正常参考wordpress自带的设置方式即可 本文只介绍Ngnix及IIS下Wordpress多站点伪静态设置方式 Ngnix的wordpress多站点伪静态设置方式 wordpress固定链接设置 try file
  • Arduino for ESP8266&ESP32适用库ESPAsyncWebServer:快速入门

    文章目录 目的 特征 安装 快速体验 注意事项 总结 目的 Arduino for ESP8266 和 Arduino for ESP32 中默认就有WebServer 不过这些WebServer都是同步的 不支持同时处理多个连接 这在很多
  • AngularJs单元测试

    这篇文章主要介绍了angularJS中的单元测试实例 本文主要介绍利用Karma和Jasmine来进行ng模块的单元测试 并用Istanbul 来生成代码覆盖率测试报告 需要的朋友们可以参考下 以下可全都是干货哦 当ng项目越来越大的时候
  • 如何在Insert插入操作之后,获取自增主键的ID值

    背景说明 MyBatis中 在大多数情况下 我们向数据库中插入一条数据之后 并不需要关注这条新插入数据的主键ID 我们也知道 正常在DAO中的插入语句虽然可以返回一个int类型的值 但是这个值表示的是插入影响的行数 而不是新插入数据的主键I
  • Unity3D如何修改Button显示的文字以及深入了解Button组件

    在创建了一个Button后 结构如图 先仔细观察一下Button的Inspector视图 发现其中竟然有一个叫Button的脚本组件 新建脚本 代码如下 并将该脚本绑定给Canvas组件 using UnityEngine UI using
  • winform记录

    SpeechSynthesizer 下边代码多次调用 会导致内存溢出outofmemory SpeechSynthesizer 需要改为全局静态 private void button Click object sender EventAr
  • 【MySQL笔记】MySQL8新特性 — 计算列

    什么叫计算列呢 简单来说就是某一列的值是通过别的列计算得来的 例如 a 列值为 1 b 列值为 2 c 列不需要手动插入 定义 a b 的结果为 c 的值 那么 c 就是计算列 是通过别的列计算得来的 在 MySQL 8 中 CREATE
  • Pytorch Torch.utils.data.Sampler

    对于 可迭代样式的数据集 数据加载顺序完全由用户定义的可迭代样式控制 这样可以更轻松地实现块读取和动态批处理大小 例如 通过每次生成一个批处理的样本 为了从数据集中读取数据 pytorch提供了Sampler基类与多个子类实现不同方式的数据