PyTorch 在 TensorDataset 上进行转换

2024-01-20

我在用着TensorDataset https://pytorch.org/docs/stable/data.html?highlight=tensordataset从 numpy 数组创建数据集。

# convert numpy arrays to pytorch tensors
X_train = torch.stack([torch.from_numpy(np.array(i)) for i in X_train])
y_train = torch.stack([torch.from_numpy(np.array(i)) for i in y_train])

# reshape into [C, H, W]
X_train = X_train.reshape((-1, 1, 28, 28)).float()

# create dataset and dataloaders
train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64)

如何应用数据增强(转变 https://pytorch.org/docs/stable/torchvision/transforms.html) to TensorDataset?

例如,使用ImageFolder https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder,我可以指定转换作为其参数之一torchvision.datasets.ImageFolder(root, transform=...).

根据这个回复 https://discuss.pytorch.org/t/anything-like-transformer-for-tensordataset/928/2?u=kharshit由 PyTorch 团队成员之一提出,默认情况下不支持。有没有其他方法可以做到这一点?

请随意询问是否需要更多代码来解释问题。


默认情况下不支持转换TensorDataset。但我们可以创建自定义类来添加该选项。但是,正如我已经提到的,大多数转换都是为PIL.Image。但无论如何,这是一个非常简单的 MNIST 示例,具有非常虚拟的变换。带有 MNIST 的 csv 文件here https://pjreddie.com/media/files/mnist_train.csv.

Code:

import numpy as np
import torch
from torch.utils.data import Dataset, TensorDataset

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

# Import mnist dataset from cvs file and convert it to torch tensor

with open('mnist_train.csv', 'r') as f:
    mnist_train = f.readlines()

# Images
X_train = np.array([[float(j) for j in i.strip().split(',')][1:] for i in mnist_train])
X_train = X_train.reshape((-1, 1, 28, 28))
X_train = torch.tensor(X_train)

# Labels
y_train = np.array([int(i[0]) for i in mnist_train])
y_train = y_train.reshape(y_train.shape[0], 1)
y_train = torch.tensor(y_train)

del mnist_train


class CustomTensorDataset(Dataset):
    """TensorDataset with support of transforms.
    """
    def __init__(self, tensors, transform=None):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
        self.transform = transform

    def __getitem__(self, index):
        x = self.tensors[0][index]

        if self.transform:
            x = self.transform(x)

        y = self.tensors[1][index]

        return x, y

    def __len__(self):
        return self.tensors[0].size(0)


def imshow(img, title=''):
    """Plot the image batch.
    """
    plt.figure(figsize=(10, 10))
    plt.title(title)
    plt.imshow(np.transpose( img.numpy(), (1, 2, 0)), cmap='gray')
    plt.show()


# Dataset w/o any tranformations
train_dataset_normal = CustomTensorDataset(tensors=(X_train, y_train), transform=None)
train_loader = torch.utils.data.DataLoader(train_dataset_normal, batch_size=16)

# iterate
for i, data in enumerate(train_loader):
    x, y = data  
    imshow(torchvision.utils.make_grid(x, 4), title='Normal')
    break  # we need just one batch


# Let's add some transforms

# Dataset with flipping tranformations

def vflip(tensor):
    """Flips tensor vertically.
    """
    tensor = tensor.flip(1)
    return tensor


def hflip(tensor):
    """Flips tensor horizontally.
    """
    tensor = tensor.flip(2)
    return tensor


train_dataset_vf = CustomTensorDataset(tensors=(X_train, y_train), transform=vflip)
train_loader = torch.utils.data.DataLoader(train_dataset_vf, batch_size=16)

result = []

for i, data in enumerate(train_loader):
    x, y = data  
    imshow(torchvision.utils.make_grid(x, 4), title='Vertical flip')
    break


train_dataset_hf = CustomTensorDataset(tensors=(X_train, y_train), transform=hflip)
train_loader = torch.utils.data.DataLoader(train_dataset_hf, batch_size=16)

result = []

for i, data in enumerate(train_loader):
    x, y = data  
    imshow(torchvision.utils.make_grid(x, 4), title='Horizontal flip')
    break

Output:

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch 在 TensorDataset 上进行转换 的相关文章

随机推荐

  • 我可以将参数传递给 rake db:seed 吗?

    我的一部分seeds rb将大量数据加载到数据库中 我希望能够有选择地加载这些数据 例如 rake db seed or rake db seed 0 只会加载运行网站所需的数据 而 rake db seed 1 也会将我的大数据文件加载到
  • Ember 数据无法读取未定义的属性“async”

    将 Ember v1 8 beta 3 与 Ember Data 1 0 beta 10 一起使用 您会收到以下错误 Error while processing route index Cannot read property async
  • 生成指定范围内的随机数 - 各种情况(int、float、inclusive、exclusive)

    given a Math random 返回 0 1 之间的数字的函数min max值来指定范围 我们如何为以下情况生成数字 我们想要的案例integer A min max B min max return Math floor Math
  • 如何更改jquery对话框按钮

    我想用我自己的按钮图像替换 jquery 对话框按钮 这样做最简洁的方法是什么 按钮上不会覆盖任何文本 我正在使用 jquery 1 4 2 和 jquery ui 1 8 1 不要应用 jQuery UI 使用的 CSS 选择器 使用具有
  • Global.asax 事件:Application_OnPostAuthenticateRequest

    我在用Application OnPostAuthenticateRequest事件在global asax to get a 经过身份验证的用户的角色和权限我还制作了自定义主体类来获取用户详细信息以及角色和权限 b 获取对该用户而言保持不
  • Jquery 动画无法正确处理列表项

    我有一个垂直的项目列表 每个项目都有一个删除按钮 当我点击其中一个的删除时 我希望下面的那些能够平滑地向上滑动 此时它们正在跳跃 下面是代码 http codepen io ovesyan19 pen chDgy http codepen
  • 这个&符号是什么意思? [复制]

    这个问题在这里已经有答案了 可能的重复 只需观看一些 Railscast 即可看到如下代码 Category Product Person each delete all 我知道它会删除这些模型的所有记录 但我不知道这是什么 delete
  • 在 iPad 上的 SwiftUI 中呈现 ActionSheet

    我已经得到了一个可以在 iPhone 设备上很好地呈现的 ActionSheet 但它在 iPad 上会崩溃 说它需要弹出窗口的位置 有人对这段代码感到幸运吗 我正在使用 iOS 13 beta 3 和 Xcode 11 beta 3 这使
  • 以编程方式将 Woocommerce 订阅开关添加到购物车

    我正在寻找一种以编程方式将两个 Woocommerce 订阅变体之间的切换添加到购物车的方法 我们正在构建一个无头 WP 网站 因此我不想使用链接来完成此操作 如中所述这个问题 https stackoverflow com questio
  • django 形式的附加字段

    我正在 Django 中创建一个表单 当我POST数据时 数据自然就发送出去了 我的问题是 我想向 POST 数据传递一个附加属性 该属性不是任何表单字段 而是一个附加属性 这样我以后就可以做类似的事情 伪代码 def form view
  • 如何在android中使用摘要式身份验证?

    我正在创建一个 Android 应用程序 通过服务器验证用户名 密码 最初服务器正在实施Basic身份验证 所以我的代码工作正常 但现在服务器已更改为Digest身份验证 所以我的旧代码不起作用 使用时应该做哪些改变Digest验证 我的代
  • 要求(img 路径)不起作用/找不到模块“。”反应js

    您好 我正在尝试使用react js 将一组图像从 this state 映射到图像标签 我遇到错误 找不到模块 这是错误 错误 找不到模块 webpack缺少模块 src components thumbnails js 26 23 24
  • 有没有办法以编程方式创建 Twiml Bin?

    我想要制作一个应用程序 用户可以在其中输入电话号码和消息 然后我可以让 Twilio 向该电话号码发送一条带有合成文本的消息 一个TwiML 代码示例 https www twilio com docs api twiml say我正在使用
  • 在python中求小矩阵

    def getMatrixMinor m i j return row j row j 1 for row in m i m i 1 上面是我在堆栈溢出中找到的代码 以便找到矩阵的逆 但是 我对 python 确实很陌生 任何人都可以解释背
  • iphone sdk中Delegate的使用

    有人可以解释一下委托在 iphone sdk 中到底是如何工作的吗 一个简单的示例如何使用委托以及使用委托的优点是什么 委托模式 http en wikipedia org wiki Delegation pattern在iPhone SD
  • 如何在 python 中与终端交互

    我正在写一个小脚本 该脚本应打开 3 个终端并独立与这些终端交互 我很清楚子流程是做到这一点的最佳方法 到目前为止我所做的 usr bin env python import subprocess term1 subprocess Pope
  • ios7 UINavigationBar 一段时间后停止在状态栏下延伸

    首先 这不是关于导航栏重叠状态栏的问题 与许多其他问题一样 UINavigationBar 我的导航控制器 完全按照我的要求对齐 问题出在我的导航栏自定义背景上 背景图像 或导航栏本身 在状态栏下随机停止扩展 在我的应用程序启动几秒钟后或当
  • 给定 N 个弹珠和 M 个楼层,找到算法来找到弹珠会破裂的最低楼层

    它与这个问题相关 两颗弹珠和一座 100 层的建筑 https stackoverflow com questions 6547 two marbles但它不一样 我们要找出最好的算法来找出最小化找到最低楼层所需的最大下降的策略 这就是我的
  • Angular 2 路由器,错误:无法激活已激活的插座

    我想我根本无法理解 Angular 2 路由 我的应用程序中有这样的结构 const routes Routes path login component LoginViewComponent path main component Mai
  • PyTorch 在 TensorDataset 上进行转换

    我在用着TensorDataset https pytorch org docs stable data html highlight tensordataset从 numpy 数组创建数据集 convert numpy arrays to