根据任务需求自制数据集:Pytorch自定义数据集Dataset代码示例(有监督学习,输入输出均为图像)

2023-11-12

一、使用torchvision.io读取照片

import numpy as np
import torch
from PIL import Image
import numpy
from matplotlib import pyplot as plt
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms,utils
import warnings
import pandas as pd
import os
import sklearn
from skimage import io,transform
import yaml
import pathlib
from torchvision.io import image



warnings.filterwarnings("ignore")
np.printoptions(np.inf)

gpu_is_available=torch.cuda.is_available()
print("GPU is {}".format( "available" if gpu_is_available else "not available"))


def read_yaml_data():
    file_path='./environments.yaml'
    with open(file_path, 'r', encoding='utf-8') as f:
        data = yaml.load(f, Loader=yaml.FullLoader)
        # print(data)
    return data


def read_imgs_paths():
    data_paths=read_yaml_data()['data_path']
    train_hazy_dir=data_paths['train_hazy_dir']
    train_gt_dir=data_paths['train_gt_dir']
    val_hazy_dir=data_paths['val_hazy_dir']
    val_gt_dir=data_paths['val_gt_dir']
    # print(data_paths)
    train_hazy_paths=list(pathlib.Path(train_hazy_dir).glob('*'))
    train_hazy_paths=[str(i) for i in train_hazy_paths]
    train_gt_paths=list(pathlib.Path(train_gt_dir).glob('*'))
    train_gt_paths=[str(i) for i in train_gt_paths]
    val_hazy_paths=list(pathlib.Path(val_hazy_dir).glob('*'))
    val_hazy_paths=[str(i) for i in val_hazy_paths]
    val_gt_paths=list(pathlib.Path(val_gt_dir).glob('*'))
    val_gt_paths=[str(i) for i in val_gt_paths]
    train_hazy_paths.sort()
    train_gt_paths.sort()
    val_hazy_paths.sort()
    val_gt_paths.sort()
    # print(train_hazy_paths)
    # print(train_gt_paths)
    # print(val_hazy_paths)
    # print(val_gt_paths)
    return (train_hazy_paths,train_gt_paths),(val_hazy_paths,val_gt_paths)


class Dehazing_Dataset(Dataset):  # data sample: {'image':image,'landmarks':landmarks}
    def __init__(self,hazy_paths,gt_paths,transform=None):
        super(Dehazing_Dataset, self).__init__()
        self.hazy_paths=hazy_paths
        self.gt_paths=gt_paths
        self.transform=transform

    def __len__(self):  # nums of data
        return len(self.hazy_paths)

    def __getitem__(self, item):  # get a sample
        hazy_img=image.read_image(self.hazy_paths[item])/255.0  # <class 'torch.Tensor'>
        gt_img=image.read_image(self.gt_paths[item])/255.0

        if self.transform:
            hazy_img=self.transform(hazy_img)
            gt_img=self.transform(gt_img)
        return hazy_img,gt_img


def get_dataset():
    (train_hazy_paths, train_gt_paths), (val_hazy_paths, val_gt_paths) = read_imgs_paths()
    train_dataset = Dehazing_Dataset(train_hazy_paths, train_gt_paths,
                                     transform=transforms.Compose([transforms.RandomCrop(size=(256,287))]))
    val_dataset = Dehazing_Dataset(val_hazy_paths, val_gt_paths,
                                     transform=transforms.Compose([transforms.RandomCrop(size=(256,287))]))

    # for i in range(len(train_dataset)):
    #     sample=train_dataset[i]
    #     show_img(sample)

    train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True,num_workers=0)
    val_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=True,num_workers=0)
    return train_dataloader,val_dataloader


def show_img(sample):
    hazy,gt=sample[0],sample[1]  # c,h,w
    hazy = hazy.permute(1, 2, 0)
    gt = gt.permute(1, 2, 0)
    plt.figure(figsize=(10,15))
    for i in range(2):
        plt.subplot(1,2,i+1)
        plt.axis('off')
        if i==0:
            plt.imshow(hazy)
        else:
            plt.imshow(gt)
    plt.show()


if __name__=='__main__':
    train_dataloader, val_dataloader=get_dataset()
    for i_batch,sample_batch in enumerate(train_dataloader):
        print(type(sample_batch))  # <class 'list'>
        print(sample_batch[0].size())  # torch.Size([2, 3, 256, 287])
        print(sample_batch[1].size())  # torch.Size([2, 3, 256, 287])

二、使用PIL读取照片

import numpy as np
import torch
from PIL import Image
from matplotlib import pyplot as plt
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms,utils
import warnings
import pandas as pd
import os
import sklearn
from skimage import io,transform
import yaml
import pathlib


warnings.filterwarnings("ignore")
np.printoptions(np.inf)

gpu_is_available=torch.cuda.is_available()
print("GPU is {}".format( "available" if gpu_is_available else "not available"))


def read_yaml_data():
    file_path='./environments.yaml'
    with open(file_path, 'r', encoding='utf-8') as f:
        data = yaml.load(f, Loader=yaml.FullLoader)
        # print(data)
    return data


def read_imgs_paths():
    data_paths=read_yaml_data()['data_path']
    train_hazy_dir=data_paths['train_hazy_dir']
    train_gt_dir=data_paths['train_gt_dir']
    val_hazy_dir=data_paths['val_hazy_dir']
    val_gt_dir=data_paths['val_gt_dir']
    # print(data_paths)
    train_hazy_paths=list(pathlib.Path(train_hazy_dir).glob('*'))
    train_hazy_paths=[str(i) for i in train_hazy_paths]
    train_gt_paths=list(pathlib.Path(train_gt_dir).glob('*'))
    train_gt_paths=[str(i) for i in train_gt_paths]
    val_hazy_paths=list(pathlib.Path(val_hazy_dir).glob('*'))
    val_hazy_paths=[str(i) for i in val_hazy_paths]
    val_gt_paths=list(pathlib.Path(val_gt_dir).glob('*'))
    val_gt_paths=[str(i) for i in val_gt_paths]
    train_hazy_paths.sort()
    train_gt_paths.sort()
    val_hazy_paths.sort()
    val_gt_paths.sort()
    # print(train_hazy_paths)
    # print(train_gt_paths)
    # print(val_hazy_paths)
    # print(val_gt_paths)
    return (train_hazy_paths,train_gt_paths),(val_hazy_paths,val_gt_paths)


class Dehazing_Dataset(Dataset):  # data sample: {'image':image,'landmarks':landmarks}
    def __init__(self,hazy_paths,gt_paths,transform=None):
        super(Dehazing_Dataset, self).__init__()
        self.hazy_paths=hazy_paths
        self.gt_paths=gt_paths
        self.transform=transform

    def __len__(self):  # nums of data
        return len(self.hazy_paths)

    def __getitem__(self, item):  # get a sample
        hazy_img=io.imread(self.hazy_paths[item])/255.0  # <class 'numpy.ndarray'>
        gt_img=io.imread(self.gt_paths[item])/255.0

        if self.transform:
            hazy_img=self.transform(hazy_img)
            gt_img=self.transform(gt_img)
        return hazy_img,gt_img


def get_dataset():
    (train_hazy_paths, train_gt_paths), (val_hazy_paths, val_gt_paths) = read_imgs_paths()
    train_dataset = Dehazing_Dataset(train_hazy_paths, train_gt_paths,
                                     transform=transforms.Compose([transforms.ToTensor(),transforms.RandomCrop(size=(256,287)),]))
    val_dataset = Dehazing_Dataset(val_hazy_paths, val_gt_paths,
                                     transform=transforms.Compose([transforms.ToTensor(),transforms.RandomCrop(size=(256,287)),]))

    # for i in range(len(train_dataset)):
    #     sample=train_dataset[i]
    #     show_img(sample)

    train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True,num_workers=0)
    val_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=True,num_workers=0)
    return train_dataloader,val_dataloader


def show_img(sample):
    hazy,gt=sample[0],sample[1]  # c,h,w
    hazy = hazy.permute(1, 2, 0)
    gt = gt.permute(1, 2, 0)
    plt.figure(figsize=(10,15))
    for i in range(2):
        plt.subplot(1,2,i+1)
        plt.axis('off')
        if i==0:
            plt.imshow(hazy)
        else:
            plt.imshow(gt)
    plt.show()


if __name__=='__main__':
    train_dataloader, val_dataloader=get_dataset()
    for i_batch,sample_batch in enumerate(train_dataloader):
        print(type(sample_batch))  # <class 'list'>
        print(sample_batch[0].size())  # torch.Size([2, 3, 256, 287])
        print(sample_batch[1].size())  # torch.Size([2, 3, 256, 287])

注意:
1.Pytorch读取图像数据的集中方式,可参考:链接: https://blog.csdn.net/qq_43665602/article/details/126281393
2.使用torchvision.io和PIL两种方式读取的数据范围为[0,255],并未进行归一化,我们可根据自己的需求对其进行归一化。

  • 方式一:transform.ToTensor()会自行将数据范围归一化为[0,1];
  • 方式二:transform.Normalize(mean,std)可通过调整合适的参数值得到自己想要的归一化结果;
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

根据任务需求自制数据集:Pytorch自定义数据集Dataset代码示例(有监督学习,输入输出均为图像) 的相关文章

随机推荐

  • Python基础-48-文本处理(逗号分隔值CSV)

    前言 python自带模块csv可以将数据以csv格式输出到文件 也可以将csv数据读回 列表 元组数据写入和读取 代码部分 coding utf 8 import csv data也可以为列表 data 1 suner001 b12345
  • Qt布局管理器

    布局管理器 利用布局管理器做出如下界面效果 标签与输入框设置伙伴关系 新建桌面应用程序 项目名TestLayout 基类QWidget 类名Widget 不勾选创建界面 include widget h include
  • ubuntu16.04.4 + cuda + cudnn + 环境变量(path)

    仅仅是总结 网上教程很多 勿喷 谢谢 时间2018年7月13日 环境 ubuntu16 04 4 注意 目前ubuntu上CUDA安装只支持该版本 64位 显卡英伟达720M 没错 很古老吧 今天一看 发现这个写的太烂了 传送门 一位大哥写
  • golang-gvm

    https mp weixin qq com s SEPP56sr16bep4C S0TLgA 详细介绍 https mp weixin qq com s biz MzAxMTA4Njc0OQ mid 2651438277 idx 4 sn
  • Android获取IMEI和MEID

    在破解微信数据库时 需要获取手机的DeviceId 但是有时会出现打不开的情况 报出file is not a database while compiling select count from sqlite master的异常 这时发现
  • linux怎么关闭超线程模式,Linux动态启用/禁用超线程技术的方法详解

    前言 intel的超线程技术能让一个物理核上并行执行两个线程 大多数情况下能提高硬件资源的利用率 增强系统性能 对于cpu密集型的数值程序 超线程技术可能会导致整体程序性能下降 鉴于此 执行OpenMP或者MPI数值程序时建议关闭超线程技术
  • “基于机器学习算法的推荐系统” 在软件静态分析领域的应用方法

    一 软件静态分析背景 软件静态分析的相当部分的内容就是发现代码中的缺陷 缺陷的形式往往五花八门 各式各样 每当发现一个缺陷 测试人员首先会感到高兴 终于抓到了一条 虫 可继而很可能会感到心虚 因为 在现有技术条件下 一条软件行业的规律是仍然
  • C语言 分割bin文件程序

    file main c author Earlybird version V1 0 0 date 30 May 2022 brief 分割bin文件为指定大小文件 attention Copyright c 2022 INESA Group
  • c++模板编程-模板类的特例化和部分特化

    类模板可以对某一个模板参数进行特化 这使得我们可以对某一个类型进行优化 你最好真是在优化 或者是针对某一个进行类型实例化后的特殊处理 全特化 如我们有以下一个简单的类模板 它提供两个公开函数 calculate计算两个T类型并返回 prin
  • JS对字符串的操作

    走进前端行业已有两年之久 对于字符串的操作也是家常便饭了 但也总在查查找找 如今对于我这个强迫症患者开始爆发了 对字符串的操作做以下整理 废话不多说直接走起来 1 字符串转换 字符串转换是最基础的要求和工作 你可以将任何类型的数据都转换为字
  • 爬虫工具之Beautiful Soup4

    Beautiful Soup4 BS4 是Python的一个第三方库 用来从HTML和XML中提取数据 安装 使用Beautiful Soup4提取HTML内容 一般要经过以下两步 1 处理源代码生成BeautifulSoup对象 这里的
  • 位运算的实践

    一 只出现一次的数字 III 1 1题目 给定一个整数数组 nums 其中恰好有两个元素只出现一次 其余所有元素均出现两次 找出只出现一次的那两个元素 你可以按 任意顺序 返回答案 进阶 你的算法应该具有线性时间复杂度 你能否仅使用常数空间
  • 深度学习实时表情识别

    背景 计算机动画代理和机器人为人机交互带来了新的维度 这使得计算机如何在日常活动中影响我们的社交生活变得至关重要 面对面的交流是一个以毫秒级的时间尺度运行的实时过程 这个时间尺度的不确定性是相当大的 这使得人类和机器有必要依赖感官丰富的感知
  • 超详细的R语言热图之complexheatmap系列(1)

    获取更多R语言和生信知识 请关注公众号 医学和生信笔记 公众号后台回复R语言 即可获得海量学习资料 目录 第一章 简介 1 1 设计理念 1 2 各章节速览 第二章 单个热图 2 1 颜色 2 2 行标题 列标题 2 3 聚类 2 3 1
  • 深度访谈:“告诉我,AI对企业到底有什么价值?”

    Eden是一家连锁经营企业的负责人 最近困扰他的是遍布全国直营和加盟店的数千名员工 如何在后疫情时代把企业的运营效率通过智能化提升一个层级 AskBot团队专注企业内部智能化 用AI去辅助人解决重复高频问题 因此才有了双方下面这一系列围绕企
  • matlab分频.m,分频器m是什么意思 音响分频器m. TW那个代表高音那个代表是低音?...

    音响分频器m TW那个代表高音那个代表是低音 T是treble 的缩写 指高音 M是mediant或middle的缩写 指中音 W是woof的缩写 指低音 音箱分频器m m 什么意思 音箱分频器m m 应该是接中音喇叭负 正两端 T T 接
  • 尚硅谷周阳老师 SpringCloud第二季学习笔记

    前言 首先感谢尚硅谷周阳老师的讲解 让我对springcloud有了很好的理解 周阳老师的讲课风格真的很喜欢 内容充实也很幽默 随口一说就是一个段子 我也算是周阳老师的忠实粉丝啦 先说说课程总体内容 以下是整理的笔记 SpringCloud
  • 带宽是什么

    带宽是什么 带宽 band width 又叫频宽 是指在固定的的时间可传输的资料数量 亦即在传输管道中可以传递数据的能力 在数字设备 中 频宽通常以bps表示 即每秒可传输之位数 在模拟设备中 频宽通常以每秒传送周期或赫兹 Hz 来表示 带
  • 超好用:免费的图床

    经常写文章的小伙伴可能会头疼 图片需要一张一张的上传 费劲也耗时 今天就推荐几款超简单的图床工具 图床就是一个在网络上存储图片的地方 目的是为了节省本地服务器空间 加快图片打开速度 话不多说 进入正题 非技术手段 1 SM MS 永久存储免
  • 根据任务需求自制数据集:Pytorch自定义数据集Dataset代码示例(有监督学习,输入输出均为图像)

    自定义数据集 一 使用torchvision io读取照片 二 使用PIL读取照片 一 使用torchvision io读取照片 import numpy as np import torch from PIL import Image i