Pytorch : Dataset和DataLoader

2023-05-16

一、综述

Dataset :对数据进行抽象,将数据包装为Dataset类。
DataLoader:在 Dataset之上对数据进行进一步处理,包括进行乱序处理,获取一个batch_size的数据等。
在这里插入图片描述

二、Dataset

在Dataset类中必须重新 getitem()、len()两个方法。

  1. 创建数据
ss=np.linspace(1,100,100)
np.savetxt("sample_data.txt", ss.reshape(-1,4))

数据格式如下所示:
在这里插入图片描述
2. 创建自定义Dataset

import numpy as np
import torch as t
from torch.utils.data import Dataset

class MyDataSet(Dataset):
    def __init__(self):
        
        #使用numy读取数据
        txt_data = np.loadtxt('sample_data.txt')
        #取数据前三列为x
        self._x = t.from_numpy(txt_data[:,:3])
        #取数据最后一列为target值
        self._y = t.from_numpy(txt_data[:,-1])
        #获取数据的长度
        self._len = len(txt_data)
        
    def __getitem__(self,item):
        #item对应的一条数据,可以是一张图,可以是一句话,总之 记住,一条数据。
        return self._x[item],self._y[item]
    
    def __len__(self):
        #带训练数据的总长度, 如果是dataframe, 直接len(df)即可,或者在init的时候传入了长度,直接返回
        return self._len

dataset =  MyDataSet()
print(len(dataset))
data =next(iter(dataset))
print(data)

在这里插入图片描述

三、 DataLoader

在这里插入图片描述

关键参数:

  • dataset :数据集
  • batch_size : 一个批次的大小
  • shuffle : 是否乱序处理
  • sampler:非常简单的多线程方法, 只要设置为>=1, 就可以多线程预读数据啦.
  • drop_last:如果数据集大小不能整除batch_size的话,是否删除最后一个batch
from torch.utils.data import DataLoader

data = MyDataSet()
dataloader = DataLoader(data,batch_size=4,shuffle=True,drop_last=True,num_workers=0)

for i,data in enumerate(dataloader):
    print('batch---->',i+1)
    inputs,labels=data
    print(inputs)
    print(labels)
    print("*"*30)

在这里插入图片描述

四、random_split

pytorch中 random_split类似于 sklearn中的train_test_split类似的功能,将数据切分为训练集、测试集、验证集。

from torch.utils.data import random_split

all_length =len(dataset)
train_size =int(0.8*all_length)
test_size = all_length - train_size

#切分数据集
train_dataset,test_dataset = random_split(dataset,[train_size,test_size])
train_loader = DataLoader(train_dataset, batch_size=3, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=3, shuffle=False, num_workers=0)
for i,curr_data in enumerate(train_loader):
    print('batch---->',i+1)
    inputs,labels=curr_data
    print(inputs)
    print(labels)
    print("*"*30)
    ```
 ![在这里插入图片描述](https://img-blog.csdnimg.cn/2021012612065338.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L0dhb3dhaGFoYQ==,size_16,color_FFFFFF,t_70)
   
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch : Dataset和DataLoader 的相关文章

  • 将 CNN Pytorch 中的预训练权重传递到 Tensorflow 中的 CNN

    我在 Pytorch 中针对 224x224 大小的图像和 4 个类别训练了这个网络 class CustomConvNet nn Module def init self num classes super CustomConvNet s
  • 如何使用宏引用数据文件?

    我有各种 Stata 数据文件 它们位于不同的文件夹中 我也有一个单do使用这些文件的文件 一次一个 有没有办法使用宏来引用我的特定数据集do file 例如 local datafile C filepath mydata dta 我们的
  • 谷歌 Colab 上的 RVL-CDIP 数据集

    我正在尝试使用以下命令在 google colab 上下载 RVL CDIP 数据集 wget load cookies tmp cookies txt https docs google com uc export download co
  • torch.mm、torch.matmul 和 torch.mul 有什么区别?

    阅读完 pytorch 文档后 我仍然需要帮助来理解之间的区别torch mm torch matmul and torch mul 由于我不完全理解它们 所以我无法简明地解释这一点 B torch tensor 1 1207 0 3137
  • LSTM 错误:AttributeError:“tuple”对象没有属性“dim”

    我有以下代码 import torch import torch nn as nn model nn Sequential nn LSTM 300 300 nn Linear 300 100 nn ReLU nn Linear 300 7
  • 如何命名在存储过程中返回的数据集的表?

    我有以下存储过程 Create procedure psfoo AS select from tbA select from tbB 然后我以这种方式访问 数据 Sql Command mySqlCommand new SqlCommand
  • 尝试理解 Pytorch 的 LSTM 实现

    我有一个包含 1000 个示例的数据集 其中每个示例都有5特征 a b c d e 我想喂7LSTM 的示例 以便它预测第 8 天的特征 a 阅读 nn LSTM 的 Pytorchs 文档 我得出以下结论 input size 5 hid
  • PyTorch 中复数矩阵的行列式

    有没有办法在 PyTorch 中计算复矩阵的行列式 torch det未针对 ComplexFloat 实现 不幸的是 目前尚未实施 一种方法是实现您自己的版本或简单地使用np linalg det 这是一个简短的函数 它计算我使用 LU
  • 为什么 pytorch matmul 在 cpu 和 gpu 上执行时得到不同的结果?

    我试图找出 numpy pytorch gpu cpu float16 float32 数字之间的舍入差异 而我发现的内容让我感到困惑 基本版本是 a torch rand 3 4 dtype torch float32 b torch r
  • BatchNorm 动量约定 PyTorch

    Is the 批归一化动量约定 http pytorch org docs master modules torch nn modules batchnorm html 默认 0 1 与其他库一样正确 例如Tensorflow默认情况下似乎
  • 根据条件过滤数据集

    我正在使用 asp net 2 0 和 c 我有一个数据集 正在获取员工信息 现在我想根据用户在搜索文本框中输入的名称来过滤网格视图 我正在这样做 DataSet ds new DataSet EmployeeInformation loa
  • 数据源和数据集的区别

    我目前正在开发一个项目 其主要任务是读取存储在 SQL 数据库中的数据并以用户友好的形式显示它们 使用的编程语言是C 我在 Borland C Builder 6 环境中工作 但我认为标题中提出的问题与编程语言或库无关 当从数据库读取数据时
  • Pytorch GPU 使用率低

    我正在尝试 pytorch 的例子https pytorch org tutorials beginner blitz cifar10 tutorial html https pytorch org tutorials beginner b
  • pytorch 的 IDE 自动完成

    我正在使用 Visual Studio 代码 最近尝试了风筝 这两者似乎都没有 pytorch 的自动完成功能 这些工具可以吗 如果没有 有人可以推荐一个可以的编辑器吗 谢谢你 使用Pycharmhttps www jetbrains co
  • PyTorch 中的连接张量

    我有一个张量叫做data形状的 128 4 150 150 其中 128 是批量大小 4 是通道数 最后 2 个维度是高度和宽度 我有另一个张量叫做fake形状的 128 1 150 150 我想放弃最后一个list array从第 2 维
  • 用于神经网络模型预测的数据的缺失值

    我目前有大量数据将用于训练预测神经网络 美国主要机场的千兆字节天气数据 我几乎每天都有数据 但有些机场的数据中存在缺失值 例如 机场在 1995 年之前可能不存在 因此在此之前我没有该特定位置的数据 此外 有些还缺少整年 可能跨度为 199
  • 样本()和r样本()有什么区别?

    当我从 PyTorch 中的发行版中采样时 两者sample and rsample似乎给出了类似的结果 import torch seaborn as sns x torch distributions Normal torch tens
  • OSError: [Errno 22] 当我尝试 .read() json 文件时

    我只是想用 Python 读取我的 json 文件 当我这样做时 我位于正确的文件夹中 我在 下载 中 我的文件名为 Books 5 json 但是 当我尝试使用 read 函数时 出现错误 OSError Errno 22 Invalid
  • 将 Pytorch LSTM 的状态参数转换为 Keras LSTM

    我试图将现有的经过训练的 PyTorch 模型移植到 Keras 中 在移植过程中 我陷入了LSTM层 LSTM 网络的 Keras 实现似乎具有三种状态类型的状态矩阵 而 Pytorch 实现则具有四种状态矩阵 例如 对于hidden l
  • PyTorch:如何批量进行推理(并行推理)

    如何在PyTorch中批量进行推理 如何并行进行推理以加快这部分代码的速度 我从进行推理的标准方法开始 with torch no grad for inputs labels in dataloader predict inputs in

随机推荐

  • Mac上使用clion基于cmake 开发gtk gtk+

    gtk gtk 43 度娘介绍 xff1a GTK 43 xff08 GIMP Toolkit 是一套源码以LGPL许可协议分发 跨平台的图形工具包 最初是为GIMP写的 xff0c 已成为一个功能强大 设计灵活的一个通用图形库 xff0c
  • Mac上使用clion基于cmake 开发qt

    安装软件 清华 在线安装的qt安装器 https mirror tuna tsinghua edu cn qt archive online installers 3 0 里面选择项如下 xff1a 红色区域必选一项 xff0c 紫色区域自
  • docker make

    FROM debian 10 RUN apt get update amp amp apt get install y no install recommends curl python3 vim python3 distutils sql
  • U-SEM体验模型——让游戏交互设计的维度更加清晰

    U SEM体验模型 让游戏交互设计的维度更加清晰 U SEM体验模型 让游戏交互设计的维度更加清晰游戏交互设计的定位游戏交互设计的场景游戏对玩家输出玩家对游戏输入 游戏交互体验的维度 U SEM体验模型游戏交互的复杂度游戏交互设计的应用 游
  • Android 12 WiFi 连接状态轮转

    WiFi 学习资料整理 gt nbsp nbsp Android WiFi 目录 WiFi 学习资料整理 gt nbsp nbsp Android WiFi 1 nbsp WifiClientModeImpl和WPA Supplicant状
  • 扩展卡尔曼滤波(EKF)算法详细推导及仿真(Matlab)

    扩展卡尔曼滤波 xff08 EKF xff09 算法详细推导及仿真 xff08 Matlab xff09 扩展卡尔曼滤波算法是解决非线性状态估计问题最为直接的一种处理方法 xff0c 尽管EKF不是最精确的 最优 滤波器 xff0c 但在过
  • Linux中线程的同步与互斥、生产者消费模型和读者写者问题、死锁问题

    线程的同步与互斥 线程是一个存在进程中的一个执行控制流 xff0c 因为线程没有进程的独立性 xff0c 在进程内部线程的大部分资源数据都是共享的 xff0c 所以在使用的过程中就需要考虑到线程的安全和数据的可靠 不能因为线程之间资源的竞争
  • 解决The following packages have unmet dependencies问题!!!

    1 安装包的时候出现如下情况 xff08 缺少依赖 xff09 xff1a 2 解决方案 xff1a 尝试多种方法无果 xff0c 最终借助一个强大的包管理工具 xff08 aptitude xff09 终于成功了 xff01 xff01
  • linux的开机过程

    1 主机加电自检 xff0c 加载 BIOS 硬件信息 2 读取 MBR 的引导文件 GRUB LILO 3 引导 Linux 内核 4 运行第一个进程 init 进程号永远为 1 5 进入相应的运行级别 6 运行终端 xff0c 输入用户
  • Ubunt 20.04 使用CDROM或ISO作为安装源

    有些项目由于安全性的要求 xff0c 需要部署在没有互联网环境的内网中 xff0c 那么如何在离线环境中给ubuntu安装相关的软件就是考验大家的linux基础知识的时候了 本文就带领大家利用CDROM或者挂载ISO镜像两种方式配置ubun
  • 关于Intellij idea 报错:Error : java 不支持发行版本5的问题

    在Intellij idea中新建了一个Maven项目 xff0c 运行时报错如下 xff1a Error java 不支持发行版本5 本地运行用的是JDK9 xff0c 测试Java的Stream操作 xff0c 报错应该是项目编译配置使
  • Spring之配置类源码深度解析

    这篇文章是继 Spring之启动过程源码解析之后 xff0c 对Spring启动过程中用到的几个重要的方法进行详细的解读 目录 一 invokeBeanFactoryPostProcessors xff0c 执行BeanFactoryPos
  • 20210702剑指Offer03(数组中重复数字)

    找出数组中重复的数字 输入 xff1a 2 3 1 0 2 5 3 输出 xff1a 2 或 3 span class token keyword class span span class token class name Solutio
  • react异步数据如ajax请求应该放在哪个生命周期?

    对于同步的状态改变 xff0c 是可以放在componentWillMount xff0c 对于异步的 xff0c 最好好放在componentDidMount 但如果此时有若干细节需要处理 xff0c 比如你的组件需要渲染子组件 xff0
  • RabbitMQ exchange交换机机制

    目录 RabbitMQ 概念exchange交换机机制 什么是交换机binding xff1f Direct Exchange交换机Topic Exchange交换机Fanout Exchange交换机Header Exchange交换机R
  • 解决open-vm-tools无法复制粘贴文件问题

    在使用vmware kali linux时一直忍受着一个情况 xff1a open vm tools Error when getting information for file 34 tmp VMwareDnD 3jTONh xxx N
  • mipmap 和 drawable 的区别

    Android 在 API level 17 加入了 mipmap 技术 xff0c 对 bitmap 图片的渲染支持 mipmap 技术 xff0c 来提高渲染的速度和质量 mipmap 是一种很早就有的技术了 xff0c 翻译过来就是纹
  • LSTM与GRU

    LSTM 与 GRU 一 综述 LSTM 与 GRU是RNN的变种 xff0c 由于RNN存在梯度消失或梯度爆炸的问题 xff0c 所以RNN很难将信息从较早的时间步传送到后面的时间步 LSTM和GRU引入门 xff08 gate xff0
  • Pytorch 实战RNN

    一 简单实例 span class token comment coding utf8 span span class token keyword import span torch span class token keyword as
  • Pytorch : Dataset和DataLoader

    一 综述 Dataset 对数据进行抽象 xff0c 将数据包装为Dataset类 DataLoader 在 Dataset之上对数据进行进一步处理 xff0c 包括进行乱序处理 xff0c 获取一个batch size的数据等 二 Dat