pytorch载入数据与对应的标签,使用torch.utils.data详解,DataLoader的使用

2023-10-28

在进行深度学习处理的时候,我们需要将数据输入到神经网络中进行训练,训练网络的学习能力,其实是根据一定的规则更新网络节点中的参数,而这个规则的来源就是依赖于数据与标签。我们需要将数据与标签相匹配,才能让网络进行训练,比如说网络学习到了一定的特征,而查阅此时的标签信息,比如说是车,那么网络就可以记住这样的特征表示的是车。这就要求我们输入的数据与数据标签是要对应的,在pytorch中,我们使用torch.utils.data 类来实现。

函数的中文文档:

torch.uutils.datahttps://pytorch-cn.readthedocs.io/zh/latest/package_references/data/函数的torch官网文档:torch.uutils.datahttps://pytorch.org/docs/stable/data.html

PyTorch数据加载实用程序的核心是torch.utils.data.DataLoader类。它表示对数据集的Python可迭代,支持

1.地图样式和可迭代样式的数据集,

2.自定义数据的加在顺序 ,

3.自动进程,

4.单线程和多线程数据加载,

5.自动内存固定。

这些都由DataLoader的构造函数参数配置:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

dataset:数据存储的地址

batch_size:每次处理的数据批量大小,一般为2的次方,如2,4,8,16,32,64等等

shuffle:是否随机读入数据,在训练集的时候一般随机读入,在验证集的时候一般不随机读入。

num_works:多线程传入数据,设置的数字即使传入的线程数,可以加快数据的读取。

其余的参数一般不做设置,除非你是炼丹大师,可以自己去尝试一下。

pin_memory:内存是否固定:对于数据加载,传递给DataLoader会自动将提取的数据张量放入固定内存中,从而更快地将数据传输到支持 CUDA 的 GPU。默认内存固定逻辑仅识别张量以及包含张量的映射和可迭代对象。默认情况下,如果固定逻辑看到自定义类型的批处理(如果有返回自定义批处理类型的批处理,则会发生这种情况),或者如果批的每个元素都是自定义类型,则固定逻辑将无法识别它们,并且它将返回该批处理(或这些元素),而不固定内存。若要为自定义批处理或数据类型启用内存固定,请在自定义类型上定义一个方法。collate_fnpin_memory()

prefetch_factor:预加载数据。每个工作线程提前加载的样本数。 意味着所有工人总共将有2 * num_workers个样本预取.

Persisitent_workers:如果 ,则数据加载程序在数据集使用一次后不会关闭工作进程。这允许保持工作线程数据集实例处于活动状态。

drop_last:如果数据集的大小不能被批大小整除,则设置为删除最后一个未完成的批处理。如果数据集的大小不能被批整除,则最后一个批将更小。

class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    # custom memory pinning method on custom type
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())

DataLoader是数据的交在程序,将数据集和采样器组合在一起,并提供对给定数据集的可迭代。DataLoader支持地图样式和可迭代样式的数据集,具有单进程和多进程加载,自定义加载顺序以及可选的自动批处理和内存固定功能。

torch.utils.data.Dataset(*args, **kwds)

表示的是Dataset的抽象类,所有其他数据都应该进行子类化。

表示从键到数据样本的映射的所有数据集都应对其进行子类化。所有子类都应该覆盖,支持为给定的键获取数据样本。子类也可以选择覆盖 ,这有望通过许多 Sampler实现和 DataLoader的默认选项返回数据集的大小。__getitem__()__len__()

默认情况下,DataLoader构造一个生成整数索引的索引采样器。要使其适应具有非整数索引的地图样式数据集,必须自己定采样器。

torch.utils.data.IterableDataset(*args, **kwds)

可迭代数据集。

表示数据样本可迭代的所有数据集都应对其进行子类化。当数据来自流时,这种形式的数据集特别有用。

所有子类都应该覆盖 ,这将返回此数据集中样本的迭代器。__iter__()

当子类与 DataLoader一起使用时,数据集中的每个项都将从 DataLoader迭代器生成。当 时,每个工作进程将具有数据集对象的不同副本,因此通常需要单独配置每个副本,以避免从工作线程返回重复数据。在工作进程中调用get_worker_info()时,将返回有关工作进程的信息。它可以在数据集的方法或DataLoader的选项中使用,以修改每个副本的行为。num_workers > 0__iter__()worker_init_fn

示例 1:在以下位置的所有工作线程之间拆分工作负载:__iter__()

class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super(MyIterableDataset).__init__()
        assert end > start, "this example code only works with end >= start"
        self.start = start
        self.end = end
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            iter_start = self.start
            iter_end = self.end
        else:  # in a worker process
            # split workload
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        return iter(range(iter_start, iter_end))
# should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
ds = MyIterableDataset(start=3, end=7)

# Single-process loading
print(list(torch.utils.data.DataLoader(ds, num_workers=0)))

# Mult-process loading with two worker processes
# Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
print(list(torch.utils.data.DataLoader(ds, num_workers=2)))

# With even more workers
print(list(torch.utils.data.DataLoader(ds, num_workers=20)))

示例 2:使用 :worker_init_fn

class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super(MyIterableDataset).__init__()
        assert end > start, "this example code only works with end >= start"
        self.start = start
        self.end = end
    def __iter__(self):
        return iter(range(self.start, self.end))
# should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
ds = MyIterableDataset(start=3, end=7)

# Single-process loading
print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
# Directly doing multi-process loading yields duplicate data
print(list(torch.utils.data.DataLoader(ds, num_workers=2)))

# Define a `worker_init_fn` that configures each dataset copy differently
def worker_init_fn(worker_id):
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset  # the dataset copy in this worker process
    overall_start = dataset.start
    overall_end = dataset.end
    # configure the dataset to only process the split workload
    per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
    worker_id = worker_info.id
    dataset.start = overall_start + worker_id * per_worker
    dataset.end = min(dataset.start + per_worker, overall_end)

# Mult-process loading with the custom `worker_init_fn`
# Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))

# With even more workers
print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))

在实际的应用中,如果我们的数据存储方式如下:

每个文件夹的名字就是数据的类名,那么数据可以使用一下方式载入:

使用datasets.ImageFolder读取数据,root是文件存储路径, transform是数据的处理方式,如裁剪,缩放,归一化等等。

train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])

读取到数据总数有6149个,一共有102个文件夹,即分为102个类,每个类的标签从0开始,即数据标签中的前几张图片都是一个类的,且类的名字是0类。 

计算train_dataset的个数:

train_num = len(train_dataset)

 载入数据到 train_loader。

 train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

这样便将数据与标签都存储到了 train_loader中。

这是 DataLoader中的一些参数设置,没有设置的参数都是默认参数。 DataLoader主要是提供一个迭代器。

然后在使用数据的时候,我们将数据加载到进度条中,然后进行使用:

train_bar = tqdm(train_loader)

 然后进行数据正向传播,损失计算,损失反向传播,优化器迭代,损失计算:

        for step, data in enumerate(train_bar): 
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

这样就完成了数据与标签的读入。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch载入数据与对应的标签,使用torch.utils.data详解,DataLoader的使用 的相关文章

  • PyQt:如何通过匿名代理使用网页

    这真让我抓狂 我想在 QWebPage 中显示一个 url 但我想通过匿名代理来实现 Code setting up the proxy proxy QNetworkProxy proxy setHostName 189 75 98 199
  • Python从int到string的快速转换

    我正在用 python 求解大量阶乘 并发现当我完成计算阶乘时 需要相同的时间才能转换为字符串以保存到文件中 我试图找到一种将 int 转换为字符串的快速方法 我将举一个计算和 int 转换时间的例子 我正在使用通用的 a str a 但感
  • Tweepy StreamListener 到 CSV

    我是 python 新手 我正在尝试开发一个应用程序 使用 Tweepy 和 Streaming API 从 Twitter 检索数据并将数据转换为 CSV 文件 问题是此代码不会创建输出 CSV 文件 也许是因为我应该将代码设置为在实现例
  • WindowsError:[错误 126] 使用 ctypes 加载操作系统时

    python代码无法在Windows 7平台上运行 def libSO lib ctypes cdll LoadLibrary ConsoleApplication2 so lib cfoo2 1 3 当我尝试运行它时 得到来自python
  • 使用 scipy curve_fit 拟合噪声指数的建议?

    我正在尝试拟合通常按以下方式建模的数据 def fit eq x a b c d e return a 1 np exp x b c np exp x d e x np arange 0 100 0 001 y fit eq x 1 1 1
  • PySide6.1 与 matplotlib 3.4 不兼容

    当我只安装PySide6时 GUI程序运行良好 但是一旦我安装了matplotlib及其依赖包 包括pyqt5 则GUI程序将无法运行并输出以下错误消息 This application failed to start because no
  • 动态 __init_subclass__ 方法的参数绑定

    我正在尝试让类装饰器工作 装饰器会添加一个 init subclass 方法到它所应用的类 但是 当该方法动态添加到类中时 第一个参数不会绑定到子类对象 为什么会发生这种情况 举个例子 这是可行的 下面的静态代码是我试图最终得到的示例 cl
  • 在Python中计算内存碎片

    我有一个长时间运行的进程 不断分配和释放对象 尽管正在释放对象 但 RSS 内存使用量会随着时间的推移而增加 如何计算发生了多少碎片 一种可能性是计算 RSS sum of allocations 并将其作为指标 即便如此 我该如何计算分母
  • Python多处理错误“ForkAwareLocal”对象没有属性“连接”

    下面是我的代码 我面临着多处理问题 我看到这个问题之前已经被问过 我已经尝试过这些解决方案 但它似乎不起作用 有人可以帮我吗 from multiprocessing import Pool Manager Class X def init
  • 乘以行并按单元格值附加到数据框

    考虑以下数据框 df pd DataFrame X a b c d Y a b d e Z a b c d 1 2 1 3 df 我想在 列中附加数字大于 1 的行 并在该行中的数字减 1 df 最好应该 然后看起来像这样 或者它可能看起来
  • 具有屏蔽无效值的 pcolormesh

    我试图将一维数组绘制为 pcolormesh 因此颜色沿 x 轴变化 但每个 x 的 y 轴保持不变 但我的数据有一些错误值 因此我使用屏蔽数组和自定义颜色图 其中屏蔽值设置为蓝色 import numpy as np import mat
  • 将文本注释到轴并对齐为圆

    我正在尝试在轴上绘制文本并将该文本与圆对齐 更准确地说 有一些具有不同坐标 x y 的点位于该圆内 并使用以下命令创建 ax scatter x y s 100 我想用圆圈连接并标记每个点 Cnameb 文本的坐标由 xp yp 定义 因此
  • Python]将两个文本文件合并为一个(逐行)[关闭]

    Closed 这个问题需要多问focused help closed questions 目前不接受答案 我是蟒蛇新手 我想做的是将文件 a 和文件 b 逐行合并到一个文件中 例如 text file a a n b n c text fi
  • 无法在 python 3.8 上将带有 webapp 的 python 部署到 azure

    我正在尝试使用部署一个测试项目Flask使用以下方法将框架迁移到 Azure 云中Azure CLI https learn microsoft com en us azure app service containers quicksta
  • 检测 IDLE 的存在/如何判断 __file__ 是否未设置

    我有一个脚本需要使用 file 所以我了解到 IDLE 没有设置这个 有没有办法从我的脚本中检测到 IDLE 的存在 if file not in globals file is not set 如果你想做一些特别的事情 file 未设置
  • Python 通过从现有 csv 文件中过滤选定的行来写入新的 csv 文件

    只是一个问题 我试图将 csv 文件中的选定行写入新的 csv 文件 但出现错误 我试图读取的 test csv 文件是这样的 两列 2013 9 1 2013 10 2 2013 11 3 2013 12 4 2014 1 5 2014
  • Django 模型:如何使用 mixin 类来覆盖 django 模型以实现 save 等功能

    我想在每次保存模型之前验证值 所以 我必须重写保存函数 代码几乎是一样的 我想把它写在 mixin 类中 但失败了 我不知道如何写 super func 我英语不好 抱歉 class SyncableMixin object def sav
  • 在 Django shell 会话期间获取 SQL 查询计数

    有没有办法打印 Django ORM 在 Django shell 会话期间执行的原始 SQL 查询的数量 Django 调试工具栏已经提供了此类信息 例如 5 QUERIES in 5 83MS但如何从 shell 中获取它并不明显 您可
  • python sklearn中的fit方法

    我问自己关于 sklearn 中拟合方法的各种问题 问题1 当我这样做时 from sklearn decomposition import TruncatedSVD model TruncatedSVD svd 1 model fit X
  • 如何获取所有mysql元组结果并转换为json

    我能够从表中获取单个数据 但是当我试图获取表上的所有数据时 我只得到一行 cnn execute sql rows cnn fetchall column t 0 for t in cnn description for row in ro

随机推荐