详细介绍torch中的from torch.utils.data.sampler相关知识

2023-12-04

PyTorch中的 torch.utils.data.sampler 模块提供了一些用于数据采样的类和函数,这些类和函数可以用于控制如何从数据集中选择样本。下面是一些常用的 Sampler 类和函数的介绍:

  1. Sampler 基类: Sampler 是一个抽象类,它定义了一个 __iter__ 方法,返回一个迭代器,用于生成数据集中的样本索引。
  2. RandomSampler : 随机采样器,它会随机从数据集中选择样本。可以设置随机数种子,以确保每次采样结果相同。
  3. SequentialSampler : 顺序采样器,它会按照数据集中的顺序,依次选择样本。
  4. SubsetRandomSampler : 子集随机采样器,它会从数据集的指定子集中随机选择样本。可以用于将数据集分成训练集和验证集等子集。
  5. WeightedRandomSampler : 加权随机采样器,它会根据指定的样本权重,进行随机采样。可以用于处理类别不平衡的问题。
  6. BatchSampler : 批次采样器,它会将样本索引分成多个批次,每个批次包含指定数量的样本索引。

这些 Sampler 类可以通过在 DataLoader 的构造函数中指定来使用。例如,可以使用 RandomSampler 来实现随机采样,使用 SubsetRandomSampler 来实现将数据集分成训练集和验证集。此外,还可以使用函数如 WeightedRandomSampler 来实现加权随机采样。

下面是使用上述 Sampler 类和函数的示例代码:

import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler, SequentialSampler, SubsetRandomSampler, WeightedRandomSampler

# 创建一个数据集
dataset = torch.utils.data.TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))

# 创建一个使用RandomSampler的DataLoader
random_loader = DataLoader(dataset, batch_size=2, sampler=RandomSampler(dataset))

# 创建一个使用SequentialSampler的DataLoader
seq_loader = DataLoader(dataset, batch_size=2, sampler=SequentialSampler(dataset))

# 创建一个使用SubsetRandomSampler的DataLoader
train_indices = [0, 1, 2, 3, 4]
val_indices = [5, 6, 7, 8, 9]
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)
train_loader = DataLoader(dataset, batch_size=2, sampler=train_sampler)
val_loader = DataLoader(dataset, batch_size=2, sampler=val_sampler)

# 创建一个使用WeightedRandomSampler的DataLoader
weights = [0.1, 0.9]
weighted_sampler = WeightedRandomSampler(weights, num_samples=10, replacement=True)
weighted_loader = DataLoader(dataset, batch_size=2, sampler=weighted_sampler)

# 使用BatchSampler将样本索引分成多个批次
batch_sampler = torch.utils.data.sampler.BatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=False)
batch_loader = DataLoader(dataset, batch_sampler=batch_sampler)

# 遍历DataLoader,输出每个批次的数据
for data, label in random_loader:
    print(data, label)
    
for data, label in seq_loader:
    print(data, label)
    
for data, label in train_loader:
    print(data, label)
    
for data, label in val_loader:
    print(data, label)
    
for data, label in weighted_loader:
    print(data, label)
    
for batch_indices in batch_sampler:
    batch_data = [dataset[idx] for idx in batch_indices]
    print(batch_data)

在这个示例中,我们首先创建了一个包含10个样本的 TensorDataset 。然后,我们创建了5个不同的 DataLoader ,每个 DataLoader 使用不同的采样器(RandomSampler、SequentialSampler、SubsetRandomSampler、WeightedRandomSampler、BatchSampler)来从数据集中选择样本。最后,我们遍历这些 DataLoader ,输出每个批次的数据。

可以通过继承 Sampler 基类来自定义采样函数。自定义采样函数需要实现 __iter__ 方法和 __len__ 方法。

__iter__ 方法需要返回一个迭代器,迭代器的每个元素都是数据集中的一个样本的索引。在这个方法中,可以自定义样本索引的选取方式,例如根据某种规则筛选样本或者将数据集分成多个子集。

__len__ 方法需要返回采样器的样本数量。如果采样器使用的是数据集的全部样本,则返回数据集的长度。

下面是一个自定义采样器的示例代码:

import torch
from torch.utils.data.sampler import Sampler

class CustomSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source
        # 在初始化方法中,可以根据需要对数据集进行处理
    
    def __iter__(self):
        # 在这个方法中,可以自定义样本索引的选取方式
        # 这里的示例是随机选取样本
        indices = torch.randperm(len(self.data_source)).tolist()
        return iter(indices)
    
    def __len__(self):
        # 在这个方法中,需要返回采样器的样本数量
        # 这里的示例是采样器的样本数量等于数据集的长度
        return len(self.data_source)

在这个示例中,我们定义了一个名为 CustomSampler 的采样器类,它继承自 Sampler 基类。在初始化方法中,我们保存了数据集,并可以根据需要对数据集进行处理。在 __iter__ 方法中,我们自定义了样本索引的选取方式,这里的示例是随机选取样本。在 __len__ 方法中,我们返回了采样器的样本数量,这里的示例是采样器的样本数量等于数据集的长度。

使用自定义采样器时,只需要将它传入 DataLoader 的构造函数即可:

dataset = torch.utils.data.TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))
custom_sampler = CustomSampler(dataset)
loader = DataLoader(dataset, batch_size=2, sampler=custom_sampler)

在这个示例中,我们首先创建了一个包含10个样本的 TensorDataset 。然后,我们使用 CustomSampler 创建了一个采样器,并将它传入 DataLoader 的构造函数。最后,我们遍历这个 DataLoader ,输出每个批次的数据。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

详细介绍torch中的from torch.utils.data.sampler相关知识 的相关文章

随机推荐

  • 面试简历的最后一道坎,实战项目经验详解

    日常猫猫缓解气氛 说起面试 实战项目经验一定是面试官问的重中之重 原因无外乎以下几点 一 面试官问项目经验的目的 通过你做的项目来判断你的专业技能 资历段位 成绩表现与简历或自我介绍中描述的是否一致 通过你对项目细节的描述 看看你是否能够独
  • 淘宝天猫商品评论采集,用rpa机器人轻松解决!

    电商行业是目前发展非常迅速的行业 淘宝天猫作为国内最大的电商平台之一 商品评论对于商家来说非常重要 商品评论可以反映出产品的好坏和用户的购买体验 是用户决策的重要参考因素 商品评论的采集对于商家来说非常重要 然而 手动采集大量评论数据耗时耗
  • 钛氧物种与钴相互作用-科学指南针

    中科院与上海交通大学合作 在碳化物作为载体的钴基费托合成研究中取得新进展 借助透射电子显微镜等技术 揭示了还原过程中碳化钛表面的钛氧物种到金属钴表面的原位迁移现象 这种增强的金属 载体的相互作用促进了费托合成反应活性 通过透射电子显微镜可以
  • 人工智能与大数据专业毕设选题汇总 最新版

    目录 前言 毕设选题 选题迷茫 选题的重要性 更多选题指导 最后 前言 大四是整个大学期间最忙碌的时光 一边要忙着备考或实习为毕业后面临的就业升学做准备 一边要为毕业设计耗费大量精力 近几年各个学校要求的毕设项目越来越难 有不少课题是研究生
  • cuda 在 torch神经网络中哪些地方可以用?

    简言之 3部分 1 数据data可以放在GPU上 2 网络net可以放在GPU上 3 损失函数loss可以放在GPU上 CUDA可以用于在torch神经网络中进行GPU加速计算 包括模型的训练和推理过程 具体来说 可以使用CUDA加速以下操
  • 一个简单的参数帮助框架,c实现

    文章目录 具体实现如下 include
  • RUM增强APP端快照配置全量会话回放与自定义协议网络请求采集功能

    一直以来 博睿数据秉承着 让每一款软件运行更完美 的产品理念 注重用户体验和反馈 以持续的技术创新 为广大用户提供轻盈 有序 精准的IT运维一体化智能可观测平台 降低运维成本 近期 博睿数据根据一体化智能可观测平台 Bonree ONE 产
  • 牛掰!《鸿蒙零基础入门学习指南》重磅来袭

    前言 不久前 华为开发者大会2023 宣布不再兼容安卓 同时宣布了 鸿飞计划 余承东承诺再投入超百亿元 以扶持和打造鸿蒙生态 鸿蒙不再兼容安卓 欲与iOS 安卓在市场三分天下 这对中国国产操作系统而言 具有划时代的意义 近期 美团 网易 微
  • Windows下环境配置Cmake、MinGW、OpenCV

    一 安装Cmake 1 选择自己需要下载的版本 下载地址 gt https github com Kitware CMake releases download v3 26 5 cmake 3 26 5 windows x86 64 msi
  • 前阿里P6花七天时间整理地方软件测试基础知识,高手请绕道

    可以说软件测试所学习的知识都是在循序渐进的 从更基础的知识逐渐延伸到困难的知识 由此可以看出 基础知识是这些重难点知识延伸的基础 想要升职加薪 基础知识必须牢靠 一 软件测试概述 1 软件缺陷 软件缺陷 又称之为 Bug 即计算机软件或程序
  • 制造业如何做生产设备管理、分析生产数据?

    本文将为大家讲解 1 设备管理的现状与问题 2 设备管理系统功能 3 制造业企业如何做生产设备管理 分析生产数据 4 制造业设备管理的价值 想要管理好设备 设备档案管理 巡检 报修 保养 分析预警等问题都是必须要考虑的 我们公司正是使用了设
  • 介绍kfold.split()的详细用法

    KFold 是交叉验证中的一种方法 其可以将数据集划分为 K 份 然后使用其中一份作为验证集 剩下的 K 1 份作为训练集 这个过程可以重复 K 次 以便每个子集都被用作验证集 KFold split 是 KFold 类中的一个方法 用于将
  • 黑马一站制造数仓实战1

    1 项目目标 一站制造 企业中项目开发的落地 代码开发 代码开发 SQL DSL SQL SparkCore SparkSQL 数仓的一些实际应用 分层体系 建模实现 2 内容目标 项目业务介绍 背景 需求 项目技术架构 选型 架构 项目环
  • 科技改变生活智能化让生活更便捷

    在科技迅猛发展的时代 我们正处于信息化和智能化的浪潮中 如何善用科技 让生活更加便捷 成为了当代人们共同关心的问题 本文将围绕这一主题 深入探讨科技如何改变我们的日常生活 让生活更智能 更方便 1 科技便捷生活 智能引领未来 这个强调了科技
  • Docker容器安装部署

    阿里云网站 mirrors aliyun com 一 安装步骤 yum源的配置 最好用环境干净的虚拟机进行安装部署 1 在 etc yum repos d中配置 docker repo 并直接配置centos源以免出现依赖性问题 2 直接列
  • electron 应用图标修改

    修改窗口图标 更换Electron应用程序的桌面图标 准备好你想要作为图标的图片文件 可以是PNG格式 安装一个可以转换图片格式为ICO的工具 例如在线转换工具 在线转换icon图标工具 将你的PNG图片文件上传并转换为ICO格式 将转换得
  • LANG、LC_CTYPE、LC_ALL环境变量

    修改编码格式 export LANG zh CN UTF 8 修改所有的编码格式 优先级高 export LC ALL zh CN UTF 8 locale是根据计算机用户所使用的语言 所在国家或者地区 以及当地的文化传统所定义的一个软件运
  • 零束科技:博睿数据是智能化路上的可靠“守护者”

    近年来 汽车市场环境的复杂性上升 全球各类不稳定因素增加 造车新势力挑战不断 车企借助云 容器化 微服务等技术加速自身数字化变革 已经成为面向未来发展的主要趋势 但随着数字化程度不断深入 自有系统的稳定性 性能 瓶颈以及由故障所带来的各类影
  • 两步解决opencsv 设置@CsvBindByPosition(position = 0)导致@CsvBindByName(column = “批次号“) 标题头不写入的问题

    获取实体类中的所有column private static
  • 详细介绍torch中的from torch.utils.data.sampler相关知识

    PyTorch中的 torch utils data sampler 模块提供了一些用于数据采样的类和函数 这些类和函数可以用于控制如何从数据集中选择样本 下面是一些常用的 Sampler 类和函数的介绍 Sampler 基类 Sample