使用可变批量大小加载数据?

2023-12-28

我目前正在研究基于补丁的超分辨率。大多数论文将图像分割成更小的补丁,然后使用这些补丁作为模型的输入。我能够使用自定义数据加载器创建补丁。代码如下:

import torch.utils.data as data
from torchvision.transforms import CenterCrop, ToTensor, Compose, ToPILImage, Resize, RandomHorizontalFlip, RandomVerticalFlip
from os import listdir
from os.path import join
from PIL import Image
import random
import os
import numpy as np
import torch

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg", ".bmp"])

class TrainDatasetFromFolder(data.Dataset):
    def __init__(self, dataset_dir, patch_size, is_gray, stride):
        super(TrainDatasetFromFolder, self).__init__()
        self.imageHrfilenames = []
        self.imageHrfilenames.extend(join(dataset_dir, x)
                                     for x in sorted(listdir(dataset_dir)) if is_image_file(x))
        self.is_gray = is_gray
        self.patchSize = patch_size
        self.stride = stride

    def _load_file(self, index):
        filename = self.imageHrfilenames[index]
        hr = Image.open(self.imageHrfilenames[index])
        downsizes = (1, 0.7, 0.45)
        downsize = 2
        w_ = int(hr.width * downsizes[downsize])
        h_ = int(hr.height * downsizes[downsize])
        aug = Compose([Resize([h_, w_], interpolation=Image.BICUBIC),
                       RandomHorizontalFlip(),
                       RandomVerticalFlip()])

        hr = aug(hr)
        rv = random.randint(0, 4)
        hr = hr.rotate(90*rv, expand=1)
        filename = os.path.splitext(os.path.split(filename)[-1])[0]
        return hr, filename

    def _patching(self, img):

        img = ToTensor()(img)
        LR_ = Compose([ToPILImage(), Resize(self.patchSize//2, interpolation=Image.BICUBIC), ToTensor()])

        HR_p, LR_p = [], []
        for i in range(0, img.shape[1] - self.patchSize, self.stride):
            for j in range(0, img.shape[2] - self.patchSize, self.stride):
                temp = img[:, i:i + self.patchSize, j:j + self.patchSize]
                HR_p += [temp]
                LR_p += [LR_(temp)]

        return torch.stack(LR_p),torch.stack(HR_p)

    def __getitem__(self, index):
        HR_, filename = self._load_file(index)
        LR_p, HR_p = self._patching(HR_)
        return LR_p, HR_p

    def __len__(self):
        return len(self.imageHrfilenames)

假设批量大小为 1,它获取图像并给出 size 的输出[x,3,patchsize,patchsize]。当批量大小为 2 时,我将有两个不同大小的输出[x,3,patchsize,patchsize](例如图像 1 可能给出[50,3,patchsize,patchsize],图像2可能给出[75,3,patchsize,patchsize])。为了处理这个问题,需要一个自定义的整理函数来沿着维度 0 堆叠这两个输出。整理函数如下:

def my_collate(batch):
    data = torch.cat([item[0] for item in batch],dim = 0)
    target = torch.cat([item[1] for item in batch],dim = 0)

    return [data, target]

这个整理函数沿着 x 连接(从上面的例子中,我终于得到[125,3,patchsize,pathsize]。出于训练目的,我需要使用 25 的小批量大小来训练模型。是否有任何方法或函数可以用来直接获得大小的输出[25 , 3, patchsize, pathsize]直接从数据加载器使用必要数量的图像作为数据加载器的输入?


以下代码片段可满足您的目的。

首先,我们定义一个 ToyDataset,它接受张量列表(tensors) of variable length in dimension 0。这与数据集返回的样本类似。

import torch
from torch.utils.data import Dataset
from torch.utils.data.sampler import RandomSampler

class ToyDataset(Dataset):
    def __init__(self, tensors):
        self.tensors = tensors

    def __getitem__(self, index):
        return self.tensors[index]

    def __len__(self):
        return len(tensors)

其次,我们定义一个自定义数据加载器。创建数据集和数据加载器的常见 Pytorch 二分法大致如下:dataset,您可以向其传递索引,它会从数据集中返回关联的样本。有一个sampler产生一个索引,有不同的策略来绘制索引,从而产生不同的采样器。采样器由batch_sampler一次绘制多个索引(与batch_size指定的数量相同)。有一个dataloader它结合了采样器和数据集,让您可以迭代数据集,重要的是数据加载器还拥有一个函数(collate_fn),它指定如何组合使用来自batch_sampler的索引从数据集中检索的多个样本。对于您的用例,通常的 PyTorch 二分法效果不佳,因为我们需要绘制索引,直到与索引关联的对象超过我们期望的累积大小,而不是绘制固定数量的索引。这意味着我们需要立即检查对象并使用这些知识来决定是否返回批次或保留绘图索引。这就是下面的自定义数据加载器的作用:

class CustomLoader(object):

    def __init__(self, dataset, my_bsz, drop_last=True):
        self.ds = dataset
        self.my_bsz = my_bsz
        self.drop_last = drop_last
        self.sampler = RandomSampler(dataset)

    def __iter__(self):
        batch = torch.Tensor()
        for idx in self.sampler:
            batch = torch.cat([batch, self.ds[idx]])
            while batch.size(0) >= self.my_bsz:
                if batch.size(0) == self.my_bsz:
                    yield batch
                    batch = torch.Tensor()
                else:
                    return_batch, batch = batch.split([self.my_bsz,batch.size(0)-self.my_bsz])
                    yield return_batch
        if batch.size(0) > 0 and not self.drop_last:
            yield batch

在这里,我们迭代数据集,在绘制索引并加载关联对象后,我们将其连接到我们之前绘制的张量(batch)。我们继续这样做,直到达到所需的尺寸,这样我们就可以切割并生产一批。我们保留行batch,我们没有屈服。因为可能会出现单个实例超过所需的batch_size的情况,所以我们使用while loop.

您可以修改这个最小CustomDataloader以 PyTorch 数据加载器的风格添加更多功能。也不需要使用 RandomSampler 来绘制索引,其他的也同样可以工作。如果您的数据很大,通过使用列表并跟踪其张量的累积长度,也可以避免重复的连接。

这是一个示例,演示了它的工作原理:

patch_size = 5
channels = 3
dim0sizes = torch.LongTensor(100).random_(1, 100)
data = torch.randn(size=(dim0sizes.sum(), channels, patch_size, patch_size))
tensors = torch.split(data, list(dim0sizes))

ds = ToyDataset(tensors)
dl = CustomLoader(ds, my_bsz=250, drop_last=False)
for i in dl:
    print(i.size(0))
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用可变批量大小加载数据? 的相关文章

随机推荐

  • 是否可以告诉自动映射器在运行时忽略映射?

    我正在使用 Entity Framework 6 和 Automapper 将实体映射到 dtos 我有这个型号 public class PersonDto public int Id get set public string Name
  • MathJax 方程换行

    嘿 如果包含的元素具有固定大小 有谁知道让 MathJax 自动换行方程的好方法 MathJax v2 0 现在包括针对长显示方程的自动 可选 换行 它是由linebreaks的部分HTML CSS您的配置块 请参阅MathJax 文档 h
  • 在 TypeScript 中解构对象时重命名剩余属性变量

    EDIT 我在github上开了一个与此相关的问题 https github com Microsoft TypeScript issues 21265 https github com Microsoft TypeScript issue
  • PostgreSQL 从 9.1 升级到 9.4 后性能下降

    将 Postgres 9 1 升级到 9 4 后 我的性能变得非常慢 下面是两个查询的示例 它们的运行速度明显慢得多 注意 我意识到这些查询可能可以被重写以更有效地工作 但是我主要担心的是升级到较新版本的 Postgres 后 它们的运行速
  • 差异化包装

    升级应用程序时 Test ServiceFabricApplicationPackage命令会对版本号未更改的每个代码包抛出错误 这表示内容已更改 即使代码未更改 我知道有一个功能可以创建部分包 但我无法使用它 我的问题是 如何检查代码包内
  • 如何在其他工作表的应用程序脚本中请求或获得谷歌电子表格访问权限?

    我正在为我的自定义函数编写 A 电子表格的应用程序脚本 并尝试使用从那里获取 B 电子表格中的值openUrl 然而 我得到了ERROR当我使用自定义函数时在电子表格中 在谷歌文档中 它说 如果您的自定义函数抛出错误消息 You do no
  • 使用powershell在其他域上查找“网络用户”?

    我想做的是 net user user1 DOMAIN 但是 我想为计算机未加入但可以访问的域执行此操作 用户分布在 DOMAIN1 和 DOMAIN2 中 我运行它的计算机已加入 DOMAIN1 但会在 DOMAIN2 上查找用户 这可以
  • 在 mongodb 的嵌套数组中插入数据[重复]

    这个问题在这里已经有答案了 可能的重复 MongoDB 更新嵌套数组中的字段 https stackoverflow com questions 9611833 mongodb updating fields in nested array
  • Safari 中的垂直居中

    我在 Safari 中使用 margin auto 0 时遇到垂直居中问题 在嵌套在带有 display inline flex 的 div 内的 div 上 它在 Firefox Chrome Opera 中工作得很好 但在 Safari
  • Travis CI 失败,因为无法接受许可证约束布局

    在我写这个问题之前 我已经搜索过同样的问题 他们确实有导出许可证 因为仍然使用 alpha 版本的约束布局 但现在android已经发布了约束布局的稳定版本 我尝试了很多设置但仍然失败 我最新的 travis yml language an
  • Django - 显示图像字段

    我刚刚开始使用 Django 还没有找到很多关于如何显示的信息imageField 所以我做了这个 模型 py class Car models Model name models CharField max length 255 pric
  • 如何判断闭合路径是否包含给定点?

    在 Android 中 我有一个 Path 对象 我碰巧知道它定义了一条闭合路径 并且我需要弄清楚给定点是否包含在路径中 我所希望的是类似的东西 路径 contains int x int y 但这似乎不存在 我寻找这个的具体原因是因为我在
  • 如何使用 signalr 将 json 对象发送到 .net 服务器

    我正在开发一个 Angular 应用程序 我必须使用 netcore 服务器和 signalR 将数据从角度形式发送到外部服务器 我可以使用信号集线器在 Angular 客户端和控制器之间建立连接 但我很困惑如何将 json 对象从客户端发
  • 在 bash 中选择不同的可执行文件

    当我想跑步的时候make为了生成一些可执行文件 它总是使用 Sunmake位于 在 usr local bin make而不是 GNU make 可以在以下位置找到 usr sfw bin gmake 我如何告诉操作系统使用 GNU mak
  • TkInter:了解解除绑定功能

    TkInter 是否unbind http effbot org tkinterbook widget htm Tkinter Widget unbind method函数阻止应用它的小部件将更多事件绑定到小部件 澄清 假设我在程序的早期将
  • Python 中以下代码有什么问题?

    我试图对一个字段实施约束 但它不会导致约束验证 而是允许保存记录而不显示任何约束消息 def check contact number self cr uid ids context None for rec in self browse
  • 在 AOSP Android 6.0 上更新 WebView

    我正在开发基于 AOSP Android 6 0 Marshmallow 的设备 我想将标准 Android webview 更新到最新版本以使用最新的 JavaScript 为此我更换了external chromium webview
  • 使 JButton 在 JTable 内可单击

    这是我想做的事情的屏幕截图 发生的情况是 JButton 显示正确 但当我单击它时没有任何反应 经过一番搜索 我发现Object由返回table getValueAt 是一个字符串而不是 JButton 这是代码 tblResult new
  • 是否有“纯粹适用的任一”的标准名称或实现?

    我经常发现我所谓的 纯粹应用性 的用处Either i e Either与Applicative只要我们不实现一个实例就可用Monad实例也是如此 newtype AEither e a AEither unAEither Either e
  • 使用可变批量大小加载数据?

    我目前正在研究基于补丁的超分辨率 大多数论文将图像分割成更小的补丁 然后使用这些补丁作为模型的输入 我能够使用自定义数据加载器创建补丁 代码如下 import torch utils data as data from torchvisio