如何在 PyTorch 数据加载器中将 RGB 图像转换为灰度图像?

2024-03-06

我已经从 MNIST 数据集中下载了一些示例图像.jpg格式。现在我正在加载这些图像来测试我的预训练模型。

# transforms to apply to the data
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# MNIST dataset
test_dataset = dataset.ImageFolder(root=DATA_PATH, transform=trans)

# Data loader
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

Here DATA_PATH包含一个带有示例图像的子文件夹。

这是我的网络定义

# Convolutional neural network (two convolutional layers)
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.network2D = nn.Sequential(
           nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
           nn.ReLU(),
           nn.MaxPool2d(kernel_size=2, stride=2),
           nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
           nn.ReLU(),
           nn.MaxPool2d(kernel_size=2, stride=2))
        self.network1D = nn.Sequential(
           nn.Dropout(),
           nn.Linear(7 * 7 * 64, 1000),
           nn.Linear(1000, 10))

    def forward(self, x):
        out = self.network2D(x)
        out = out.reshape(out.size(0), -1)
        out = self.network1D(out)
        return out

这是我的推理部分

# Test the model
model = torch.load("mnist_weights_5.pth.tar")
model.eval()

for images, labels in test_loader:
   outputs = model(images.cuda())

当我运行此代码时,出现以下错误:

RuntimeError: Given groups=1, weight of size [32, 1, 5, 5], expected input[1, 3, 28, 28] to have 1 channels, but got 3 channels instead

据我所知,图像以 3 通道 (RGB) 的形式加载。那么我如何将它们转换为单通道dataloader?

更新: 我变了transforms包括Grayscale option

trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), transforms.Grayscale(num_output_channels=1)])

但现在我得到这个错误

TypeError: img should be PIL Image. Got <class 'torch.Tensor'>

使用时ImageFolder类并且没有自定义加载器,pytorch 使用 PIL 加载图像并将其转换为 RGB。如果 torchvision 图像后端是 PIL,则默认加载器:

def pil_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

您可以使用torchvision 的灰度变换中的函数。它将 3 通道 RGB 图像转换为 1 通道灰度图像。了解更多相关信息,请访问here https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.Grayscale

示例代码如下,

import torchvision as tv
import numpy as np
import torch.utils.data as data
dataDir         = 'D:\\general\\ML_DL\\datasets\\CIFAR'
trainTransform  = tv.transforms.Compose([tv.transforms.Grayscale(num_output_channels=1),
                                    tv.transforms.ToTensor(), 
                                    tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainSet        = tv.datasets.CIFAR10(dataDir, train=True, download=False, transform=trainTransform)
dataloader      = data.DataLoader(trainSet, batch_size=1, shuffle=False, num_workers=0)
images, labels  = iter(dataloader).next()
print (images.size())
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在 PyTorch 数据加载器中将 RGB 图像转换为灰度图像? 的相关文章

  • 使用 Flask-restful RequestParser 进行嵌套验证

    使用烧瓶宁静 http flask restful readthedocs org 微框架 我在构建一个RequestParser这将验证嵌套资源 假设预期的 JSON 资源格式为 a list obj1 1 obj2 2 obj3 3 o
  • Python lambda 函数没有在 for 循环中正确调用[重复]

    这个问题在这里已经有答案了 我正在尝试使用 Python 中的 Tkinter 制作一个计算器 我使用 for 循环来绘制按钮 并且尝试使用 lambda 函数 以便仅在按下按钮时调用按钮的操作 而不是在程序启动时立即调用 然而 当我尝试这
  • Pytorch 损失为 nan

    我正在尝试用 pytorch 编写我的第一个神经网络 不幸的是 当我想要得到损失时遇到了问题 出现以下错误信息 RuntimeError Function LogSoftmaxBackward0 returned nan values in
  • 是否可以将名为“None”的值添加到枚举类型?

    我可以将名为 None 的值添加到枚举中吗 例如 from enum import Enum class Color Enum None 0 represent no color at all red 1 green 2 blue 3 co
  • MySQL 的 read_sql() 非常慢

    我将 MySQL 与 pandas 和 sqlalchemy 一起使用 然而 它的速度非常慢 对于一个包含 1100 万行的表 一个简单的查询需要 11 分钟以上才能完成 哪些行动可以改善这种表现 提到的表没有主键 并且仅由一列索引 fro
  • 如何在代码中停止 autopep8 未安装消息

    我是一名新的 Python 程序员 使用 Mac 版本的 VS Code 1 45 1 创建 Django 项目 我安装了 Python 和 Django 扩展 每次我保存 Django 文件时 代码都会弹出此窗口 Formatter au
  • 通过 pyodbc 连接到 Azure SQL 数据库

    我使用 pyodbc 连接到本地 SQL 数据库 该数据库工作正常 SQLSERVERLOCAL Driver SQL Server Native Client 11 0 Server localdb v11 0 integrated se
  • python中的语音识别持续时间设置问题

    我有一个 Wav 格式的音频文件 我想转录 我的代码是 import speech recognition as sr harvard sr AudioFile speech file wav with harvard as source
  • 导入错误:无法导入名称 urandom

    我正在构建一个新的 Linux 环境 并在 Python 上看到以下错误 python c import random Traceback most recent call last File
  • __subclasses__ 没有显示任何内容

    我正在实现一个从适当的子类返回对象的函数 如果我搬家SubClass from base py 没有出现子类 subclasses 它们必须在同一个文件中吗 也许我从来没有直接导入subclass py对Python隐藏子类 我能做些什么
  • 调试 python Web 服务

    我正在使用找到的说明here http www diveintopython net http web services user agent html 尝试检查发送到我的网络服务器的 HTTP 命令 但是 我没有看到按照教程中的建议在控制
  • NumPy 数组不可 JSON 序列化

    创建 NumPy 数组并将其保存为 Django 上下文变量后 加载网页时收到以下错误 array 0 239 479 717 952 1192 1432 1667 dtype int64 is not JSON serializable
  • Tensorflow:提要字典错误:您必须为占位符张量提供值

    我有一个错误 我无法找出原因 这是代码 with tf Graph as default global step tf Variable 0 trainable False images tf placeholder tf float32
  • 如何克服 numpy.unique 的 MemoryError

    我正在使用 Numpy 版本 1 11 1 并且必须处理一个二维数组 my arr shape 25000 25000 所有值都是整数 我需要一个唯一的数组值列表 使用时lst np unique my arr 我正进入 状态 Traceb
  • 使用 JSON 可序列化枚举自动生成棉花糖模式

    创建与我的模型相同的棉花糖模式的日子已经一去不复返了 我发现这个优秀的答案 https stackoverflow com a 42892443 4097322这解释了我如何使用简单的装饰器从 SQA 模型自动生成模式 因此我实现了它并替换
  • 是否可以使用 Python 中的密码安全地加密然后解密数据?

    我在 python 程序中有一些数据 我想在使用密码写入文件之前对其进行加密 然后在使用它之前读取并解密它 我正在寻找一些可以根据密码进行加密和解密的安全对称算法 这个问题 https stackoverflow com questions
  • 使用 boto3 将 csv 文件保存到 s3

    我正在尝试写入 CSV 文件并将其保存到 s3 中的特定文件夹 存在 这是我的代码 from io import BytesIO import pandas as pd import boto3 s3 boto3 resource s3 d
  • 矩阵求逆 (3,3) python - 硬编码与 numpy.linalg.inv

    对于大量矩阵 我需要计算定义为的距离度量 尽管我确实知道强烈建议不要使用矩阵求逆 但我没有找到解决方法 因此 我尝试通过对矩阵求逆进行硬编码来提高性能 因为所有矩阵的大小均为 3 3 我预计这至少会是一个微小的改进 但事实并非如此 为什么
  • 应用程序的外观 - Py2exe / wxPython

    所以我的问题是我的应用程序的外观和感觉 因为它看起来像一个旧的外观应用程序 它是一个 wxPython 应用程序 在 python 上它运行良好并且看起来不错 但是当我使用 py2exe 将其转换为 exe 时 外观很糟糕 现在我知道如果你
  • 用于获取有关 SVN 存储库信息的 Python 库?

    我正在寻找一个可以从 SVN 存储库中提取 至少 以下信息的库 not工作副本 修订号及其作者和提交消息 每个修订版中的更改 添加 删除 修改文件 有Python库可以做到这一点吗 对于作者和提交消息 我可以解析 db revprops 0

随机推荐