torch.utils.data.dataloader参数collate_fn简析

2023-05-16

torch.utils.data.DataLoader是pytorch提供的数据加载类,初始化函数如下,

torch.utils.data.DataLoader(dataset,batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

dataset,batch_size等参数重要且容易理解,而collate_fn参数就不太直白,官方解释为:

collate_fn (callableoptional) – merges a list of samples to form a mini-batch

不明不白。

其实,collate_fn可理解为函数句柄、指针...或者其他可调用类(实现__call__函数)。 函数输入为list,list中的元素为欲取出的一系列样本。具体如下

indices = next(self.sample_iter)
batch = self.collate_fn([dataset[i] for i in indices])

其中self.sampler_iter即采样器,返回下一个batch中样本的序号,indices。

通过collate_fn函数可以对这些样本做进一步的处理(任何你想要的处理),原则上返回值应当是一个有结构的batch。而DataLoader每次迭代的返回值就是collate_fn的返回值。

以图像关键点训练数据采样举例:

采样器调用我们自定义数据类的__getitem__(self, idx)函数获取训练样本,假设__getitem__函数返回字典:

{
"image": [[...],[...]]#一副图像,tensor,格式1CHW
 "keypoints":[[x1,y1],[x2,y2],...]#图像中的关键点,tensor
}

那么通过sampler采样一个batch的样本时,返回的是一个list,格式如下

[
{"image": [[...],[...]],
 "keypoints":[[x1,y1],[x2,y2],...]},

{"image": [[...],[...]],
 "keypoints":[[x1,y1],[x2,y2],...]}
]

我们知道,神经网络在处理图像数据时,可以一次输入一个batch的数据,格式为(BCHW)的tensor,因此我们需要将数据变成如下格式

{
"images":[[[...]],[[...]]]#多幅图像,Tensor,格式:BCHW
"keypoints":[tensor,tensor]#每个元素都是一个list或tensor,对应与各image中的关键点
}

这个转换过程就可以通过collate_fn函数完成。

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

torch.utils.data.dataloader参数collate_fn简析 的相关文章

  • css data:image/svg+xml 不显示

    原因 xff1a 新版chrome不支持 需要改成 23 如 xff1a test span class token punctuation span content url span class token punctuation spa
  • pytorch函数详解

    pytorch函数详解 在typora这里写之后复制到简书上 1 torchvision 1 1 transforms Compose transforms 把几个转换组合 example from PIL import Image t t
  • Docker部署AI算法教程

    docker上部署算法除了一些推理框架外 有时候会自己用torch推理加上一些web应用 下面写下自己用的一套方法 Docker cuda10 1 miniconda3 torch1 7 1 docker要求19 03以上才支持cuda 1
  • PyTorch torch.optim.lr_scheduler 学习率设置 调参 -- CosineAnnealingLR

    lr scheduler 学习率 学习率的参数调整是深度学习中一个非常重要的一项 Andrew NG 吴恩达 认为一般如果想调参数 第一个一般就是学习率 作者初步学习者 有错误直接提出 热烈欢迎 共同学习 感谢Andrew ng的机器学习和
  • 使用pandas对xlsx文件的基本操作

    起因 因最近实习期间 要求查看 xlsx文件中数据是否有误 由于数据较多 想用python去执行 结果发现网上对xlsx文件操作或是太旧 大多难以应用 所以自己整理了一下 以备自己后用 模拟一个测试数据集data test xlsx文件 文
  • python torch在dataloader处卡死

    torch在dataloader处卡死 1 解决方案 2 调试历程 2 1 网上搜索了很多方法 尝试无果 故亲自调试 2 2 进入函数 发现一段神奇的代码 1 解决方案 num workers设置为0 一般解决大多数问题 修改读取数据部分代
  • 使用Torch nngraph实现LSTM

    什么是RNN RNN 多层反馈RNN Recurrent neural Network 循环神经网络 神经网络是一种节点定向连接成环的人工神经网络 这种网络的内部状态可以展示动态时序行为 不同于前馈神经网络的是 RNN可以利用它内部的记忆来
  • Kafka工具类

    package com cnic utils import org apache flink api common serialization SimpleStringSchema import org apache flink api c
  • map与java bean相互转换

    map与java对象的相互转换 1 使用org apache commons beanutils转换 2 使用Introspector转换 3 使用reflect转换 4 使用net sf cglib beans BeanMap转换 5 使
  • Anaconda3中torch.cuda.is_available()返回false的可能解决办法

    1 问题 在CUDA cudnn 已装好 指令 conda install pytorch torchvision torchaudio pytorch cuda 11 7 c pytorch c nvidia 一直转圈 不得已使用pip指
  • vue中computed的属性对data中的属性赋值为undefined的原因

    场景 我在computed中return了一个值 然后在data中直接将它复制给另一个属性 结果data中的属性值为undefined 代码示例 timer为undefined 原因 在这里很容易想到是执行顺序的问题 computed中的属
  • 在 Windows 上使用 Luarocks 安装 Torch7 并出现 mingw 构建错误

    我按照说明进行操作here并与 Mingw 从头开始 建立 Lua 和 Luarocks 一切工作正常 我能够安装rocks 包括那些需要像LuaSocket这样编译的东西 我按照说明进行操作Torch7通过 luarocks 安装 Tor
  • Sqlite:多次更新(查找和替换)不区分大小写

    我使用 DB Browser for SQLite 来可视化和更新 sqlite 文件 我能够运行区分大小写的查询来更新一些文本 如下所示 UPDATE itemNotes SET note REPLACE note sometext ab
  • 为什么 PyTorch C++ 扩展比其等效的 numba 版本慢得多?

    我一直在尝试各种选项来加速 PyTorch 中的一些 for 循环逻辑 这样做的两个明显的选择是使用numba https stackoverflow com a 75580380 1804173 or 编写自定义 C 扩展 https p
  • Databricks 笔记本挂着 pytorch

    我们遇到 Databricks 笔记本问题 我们的一个笔记本单元似乎挂起 而驱动程序日志确实显示该笔记本单元已被执行 有谁知道为什么我们的笔记本单元一直挂起并且无法完成 请参阅下面的详细信息 情况 我们正在训练 ML 模型pytorch在
  • Lua - 删除非空目录

    我正在尝试删除中的非空目录Lua但没有成功 我尝试了以下方法 os remove path to dir 并得到错误 Directory not empty 39当文件数为39时path to dir 还尝试过 require lfs lf
  • 如何在非 NVIDIA 设置上加速深度学习?

    由于我只有 AMD A10 7850 APU 并且没有资金购买 800 1200 美元的 NVIDIA 显卡 因此我正在尝试利用我拥有的资源通过 TensorFlow Keras 加速深度学习 最初 我使用了 Tensorflow 的预编译
  • 火炬。 pin_memory 在 Dataloader 中如何工作?

    我想了解 Dataloader 中的 pin memory 是如何工作的 根据文档 pin memory bool optional If True the data loader will copy tensors into CUDA p
  • 为什么在 cmd 中安装任何 python 模块时会收到这些错误“警告:忽略无效的分发 -yproj ”

    警告 忽略无效的分发 yproj c users space junk appdata local programs python python310 lib site packages 警告 忽略无效的分发 yproj c users s
  • Raspberry 上的 Libtorch 无法加载 pt 文件,但可以在 ubuntu 上运行

    我正在尝试在 Raspberry PI 上使用 libtorch 构建 C 程序 该程序在 Ubuntu 上运行 但在 Raspberry 上构建时出现以下错误 error use of deleted function void torc

随机推荐