将类对象添加到 Pytorch Dataloader:批次必须包含张量

2023-12-02

我有一个自定义 Pytorch 数据集,它返回一个包含类对象“查询”的字典。

class QueryDataset(torch.utils.data.Dataset):

    def __init__(self, queries, values, targets):
        super(QueryDataset).__init__()
        self.queries = queries
        self.values = values
        self.targets = targets

    def __len__(self):
        return self.values.shape[0]

    def __getitem__(self, idx):
        sample = DeviceDict({'query': self.queries[idx],
                             "values": self.values[idx],
                             "targets": self.targets[idx]})
        return sample

问题是,当我将查询放入数据加载器时,我得到default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'query.Query'>。有没有办法在我的数据加载器中拥有一个类对象?它爆炸于next(iterator)在下面的代码中。

train_queries = QueryDataset(train_queries)
train_loader = torch.utils.data.DataLoader(train_queries,
                                           batch_size=10],
                                           shuffle=True,
                                           drop_last=False)
for i in range(epochs):
    iterator = iter(train_loader)
    for i in range(len(train_loader)):
        batch = next(iterator)
        out = model(batch)
        loss = criterion(out["pred"], batch["targets"])
        self.optimizer.zero_grad()
        loss.sum().backward()
        self.optimizer.step()

你需要定义你自己的科莱特_fn为此。 一个草率的方法只是为了向您展示这里的东西是如何工作的,会是这样的:

import torch
class DeviceDict:
    def __init__(self, data):
        self.data = data 

    def print_data(self):
        print(self.data)

class QueryDataset(torch.utils.data.Dataset):

    def __init__(self, queries, values, targets):
        super(QueryDataset).__init__()
        self.queries = queries
        self.values = values
        self.targets = targets

    def __len__(self):
        return 5

    def __getitem__(self, idx):
        sample = {'query': self.queries[idx],
                 "values": self.values[idx],
                 "targets": self.targets[idx]}
        return sample

def custom_collate(dict):
    return DeviceDict(dict)

dt = QueryDataset("q","v","t")
dl = torch.utils.data.DataLoader(dtt,batch_size=1,collate_fn=custom_collate)
t = next(iter(dl))
t.print_data()

基本上colate_fn允许您实现自定义批处理或添加对自定义数据类型的支持,如我之前提供的链接中所述。
正如你所看到的,它只是显示了概念,你需要根据自己的需要更改它。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

将类对象添加到 Pytorch Dataloader:批次必须包含张量 的相关文章

随机推荐

  • ImageIcon 不会更新具有相同 URL 的新图像

    我想使用 JLabel Icon 来显示来自我网站的图像 http xxx xxx xxx xxx java pic test jpg 我有一个刷新按钮来新建一个新的 JLabel 和 ImageIcon 为了获取最新的图像 程序运行成功
  • matplotlib 中的 Pandas 自动日期时间格式

    我经常在一个图中绘制来自不同来源的多个时间序列数据 其中一些需要使用 matplotlib 格式化 x 轴时 我使用 matplotlibautofmt xdate 但我更喜欢 pandas 的自动格式化 我知道我可以使用手动设置格式set
  • BeautifulSoup 不同的解析器

    任何人都可以详细说明 html parser 和 html5lib 等解析器之间的区别吗 我偶然发现了一个奇怪的行为 当使用 html parser 时 它会忽略特定位置的所有标签 看看这段代码 from bs4 import Beauti
  • 在 jquery DataTables 中跳过一行的渲染

    如果在初始化期间满足条件 我想跳过行渲染 但我不知道到底将其放置在哪里 我应该把它放进去吗fnCreatedRow or fnPreDrawCallback 我怎样才能做到这一点 这是我的代码 var users tbl users tbl
  • Cygnus 启动错误:ClassNotFoundException

    我的环境是CentOS 6 6的VM 我按照中的说明进行操作https github com telefonicaid fiware cygnus blob master doc quick start guide md安装天鹅座 还安装了
  • 如何使用 Yocto 生成适用于 Windows 的工具链?

    关于我的最后一个问题我问如何获得 Qt 工具链 我在 Linux 主机上尝试过 它可以工作 现在我需要知道如何使该工具链在 Windows 平台上工作 或者我需要什么 Yocto 设置来生成 Qt Windows SDK 安装程序 Woul
  • 使用 numpy.vectorize() 旋转 NumPy 数组的所有元素

    我正处于学习 NumPy 的开始阶段 我有一个 3x3 矩阵的 Numpy 数组 我想创建一个新数组 其中每个矩阵都旋转 90 度 我研究过这个answer但我仍然不明白我做错了什么 import numpy as np 3x3 m np
  • virtualenv pip mysqldb mac os X python

    我试过这个http jazstudios blogspot com 2010 07 installing mysql python mysqldb in html提示在 virtualenv 名为dogme 这篇文章指出了两件重要的事情 e
  • 如何在按下后以编程方式关闭 SearchView?

    我有同样的问题 我发现here我将重申这一点 因为该解决方案并不是 100 完全符合我的需要 目前 我的应用程序的操作栏中有一个 SearchView 当我单击搜索图标时 SearchView 会展开 并且键盘会按预期弹出 单击 Searc
  • 使用未渲染的控件的视觉画笔?

    我现在正在考虑一个想法 但碰壁了 我正在使用控制台应用程序在内存中创建一个视觉控件 准确地说是 DevExpress 图表控件 然后我尝试使用 VisualBrush 将该控件保存到图像中 但它不起作用 因为 我假设 该控件没有被吸引到屏幕
  • Qt4:使全屏窗口无法绕过(锁定屏幕)?

    我的应用程序是一个操作系统锁定屏幕 如 GDM 的锁定屏幕或 KDE 的锁定屏幕 因此我试图使其具有类似的功能 我试图让我的应用程序的窗口悬停在上面all其他窗口并禁用 拦截所有键盘快捷键 ALT TAB CTRL ALT D等 这会导致它
  • 当包含长文本视图时,滚动视图在 ics(android 4.0)上非常慢

    这是我的问题 我正在开发一个新闻应用程序 我使用滚动视图包装文本视图来显示新闻内容 但我发现当textview很长时 在android 4 0 ics上滚动非常慢 并且文本越长 滚动越慢 在 Android 2 3 设备上 一切都如预期的那
  • asp.net web API HTTP PUT 方法

    我有一些资源 UserProfile public UserProfile public string Email get set public string Password get set 我想分别更改电子邮件和密码 同一时间只能为用户
  • Python,睡眠一些代码而不是全部

    我遇到一种情况 在代码中的某个时刻我想触发多个计时器 代码将继续运行 但在某个时刻这些函数将触发并从 给定列表中删除一个项目 类似 但不完全像下面的代码 问题是 我希望这些函数等待一定的时间 我知道如何使用睡眠的唯一方法是使用睡眠 但是当我
  • 如何在 Node (\u00f6) 中转义 UTF-8 字符?

    我有一个使用 ISO Latin 编码的属性文件 但使用特殊字符作为 UTF 8 转义序列 例如以下字符串 Einstellungen l u00f6schen 我尝试了很多不同的组合iconv punycode and JSON pars
  • JavaScript - 如何为选定的文本设置标记?

    我需要highlight选定的文本JavaScript 没有 jQuery 并且有control points or markers 左和右 我真的不知道如何称呼它们 就像在手机上一样 所以我可以随时通过拖动任何控制点来扩展选择 例子 ht
  • Slick 3.0.0 - 仅更新包含非空值的行

    有一个包含列的表 class Data tag Tag extends Table DataRow tag data def id column Int id O PrimaryKey def name column String name
  • 一个简单的 xml 元素如何解组为 golang 结构?

    假设以下 xml 元素具有属性和浮点值
  • 我可以混淆已编译的 .NET 可执行文件/程序集吗?

    所以我试图在编译后混淆我的程序 我很确定你就是这样做的 我正在使用一个非常流行的免费软件 名为 EazFuscator 它有一个很好的小命令行实用程序 所以如果我去 Eazfuscator NET MyProgram exe 它会成功地混淆
  • 将类对象添加到 Pytorch Dataloader:批次必须包含张量

    我有一个自定义 Pytorch 数据集 它返回一个包含类对象 查询 的字典 class QueryDataset torch utils data Dataset def init self queries values targets s