一、使用torchvision.io读取照片
import numpy as np
import torch
from PIL import Image
import numpy
from matplotlib import pyplot as plt
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms,utils
import warnings
import pandas as pd
import os
import sklearn
from skimage import io,transform
import yaml
import pathlib
from torchvision.io import image
warnings.filterwarnings("ignore")
# Bug fix: np.printoptions(...) only builds a context manager and changes
# nothing when called standalone (and its first positional arg is
# `precision`, not `threshold`). np.set_printoptions is the call that
# actually makes numpy print full arrays without truncation.
np.set_printoptions(threshold=np.inf)
gpu_is_available = torch.cuda.is_available()
print("GPU is {}".format("available" if gpu_is_available else "not available"))
def read_yaml_data():
    """Load and return the project configuration from ./environments.yaml."""
    config_path = './environments.yaml'
    with open(config_path, 'r', encoding='utf-8') as handle:
        return yaml.load(handle, Loader=yaml.FullLoader)
def read_imgs_paths():
    """Collect sorted image paths for the train/val hazy and ground-truth sets.

    Reads the four directory locations from environments.yaml (key
    'data_path') and returns
    ``((train_hazy, train_gt), (val_hazy, val_gt))``, each element a sorted
    list of path strings. Sorting keeps the hazy and gt lists index-aligned,
    assuming the paired directories use matching file names.
    """
    def _sorted_paths(directory):
        # Every entry in the directory is treated as an image file.
        return sorted(str(p) for p in pathlib.Path(directory).glob('*'))

    data_paths = read_yaml_data()['data_path']
    train_hazy_paths = _sorted_paths(data_paths['train_hazy_dir'])
    train_gt_paths = _sorted_paths(data_paths['train_gt_dir'])
    val_hazy_paths = _sorted_paths(data_paths['val_hazy_dir'])
    val_gt_paths = _sorted_paths(data_paths['val_gt_dir'])
    return (train_hazy_paths, train_gt_paths), (val_hazy_paths, val_gt_paths)
class Dehazing_Dataset(Dataset):
    """Paired hazy / ground-truth image dataset read with torchvision.io.

    ``hazy_paths`` and ``gt_paths`` must be index-aligned lists of image
    file paths. ``transform``, if given, is applied to the two images of a
    pair JOINTLY (stacked along the channel axis) so that random transforms
    such as RandomCrop select the same region for both images.
    """
    def __init__(self, hazy_paths, gt_paths, transform=None):
        super(Dehazing_Dataset, self).__init__()
        self.hazy_paths = hazy_paths   # paths of hazy inputs
        self.gt_paths = gt_paths       # paths of matching clear targets
        self.transform = transform     # optional joint transform

    def __len__(self):
        # One sample per hazy image; gt list is assumed the same length.
        return len(self.hazy_paths)

    def __getitem__(self, item):
        # read_image returns a uint8 (C,H,W) tensor; scale to [0,1] floats.
        hazy_img = image.read_image(self.hazy_paths[item]) / 255.0
        gt_img = image.read_image(self.gt_paths[item]) / 255.0
        if self.transform:
            # Bug fix: applying a random transform to each image separately
            # crops DIFFERENT regions and destroys hazy/gt alignment.
            # Stack along channels, transform once, then split back.
            c = hazy_img.shape[0]
            stacked = self.transform(torch.cat([hazy_img, gt_img], dim=0))
            hazy_img, gt_img = stacked[:c], stacked[c:]
        return hazy_img, gt_img
def get_dataset():
    """Build train/val DataLoaders yielding randomly-cropped (hazy, gt) pairs."""
    (train_hazy, train_gt), (val_hazy, val_gt) = read_imgs_paths()
    crop = transforms.Compose([transforms.RandomCrop(size=(256, 287))])
    train_dataset = Dehazing_Dataset(train_hazy, train_gt, transform=crop)
    val_dataset = Dehazing_Dataset(val_hazy, val_gt, transform=crop)
    train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)
    val_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=True, num_workers=0)
    return train_dataloader, val_dataloader
def show_img(sample):
    """Plot a (hazy, gt) pair side by side; expects (C,H,W) tensors."""
    # Convert both tensors from (C,H,W) to (H,W,C) for imshow.
    panels = [img.permute(1, 2, 0) for img in (sample[0], sample[1])]
    plt.figure(figsize=(10, 15))
    for idx, img in enumerate(panels):
        plt.subplot(1, 2, idx + 1)
        plt.axis('off')
        plt.imshow(img)
    plt.show()
if __name__ == '__main__':
    train_dataloader, val_dataloader = get_dataset()
    for batch_index, batch in enumerate(train_dataloader):
        # Each batch is a list: [hazy tensors, gt tensors].
        print(type(batch))            # <class 'list'>
        print(batch[0].size())        # torch.Size([2, 3, 256, 287])
        print(batch[1].size())        # torch.Size([2, 3, 256, 287])
二、使用PIL读取照片
import numpy as np
import torch
from PIL import Image
from matplotlib import pyplot as plt
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms,utils
import warnings
import pandas as pd
import os
import sklearn
from skimage import io,transform
import yaml
import pathlib
warnings.filterwarnings("ignore")
# Bug fix: np.printoptions(...) only builds a context manager and changes
# nothing when called standalone (and its first positional arg is
# `precision`, not `threshold`). np.set_printoptions is the call that
# actually makes numpy print full arrays without truncation.
np.set_printoptions(threshold=np.inf)
gpu_is_available = torch.cuda.is_available()
print("GPU is {}".format("available" if gpu_is_available else "not available"))
def read_yaml_data():
    """Read ./environments.yaml and return its contents as a dict."""
    yaml_file = './environments.yaml'
    with open(yaml_file, 'r', encoding='utf-8') as fp:
        parsed = yaml.load(fp, Loader=yaml.FullLoader)
    return parsed
def read_imgs_paths():
    """Collect sorted image paths for the train/val hazy and ground-truth sets.

    Reads the four directory locations from environments.yaml (key
    'data_path') and returns
    ``((train_hazy, train_gt), (val_hazy, val_gt))``, each element a sorted
    list of path strings. Sorting keeps the hazy and gt lists index-aligned,
    assuming the paired directories use matching file names.
    """
    def _sorted_paths(directory):
        # Every entry in the directory is treated as an image file.
        return sorted(str(p) for p in pathlib.Path(directory).glob('*'))

    data_paths = read_yaml_data()['data_path']
    train_hazy_paths = _sorted_paths(data_paths['train_hazy_dir'])
    train_gt_paths = _sorted_paths(data_paths['train_gt_dir'])
    val_hazy_paths = _sorted_paths(data_paths['val_hazy_dir'])
    val_gt_paths = _sorted_paths(data_paths['val_gt_dir'])
    return (train_hazy_paths, train_gt_paths), (val_hazy_paths, val_gt_paths)
class Dehazing_Dataset(Dataset):
    """Paired hazy / ground-truth image dataset read with skimage.io.

    ``hazy_paths`` and ``gt_paths`` must be index-aligned lists of image
    file paths. ``transform``, if given, is applied to the two images of a
    pair JOINTLY (stacked along the channel axis) so that random transforms
    such as RandomCrop select the same region for both images.
    """
    def __init__(self, hazy_paths, gt_paths, transform=None):
        super(Dehazing_Dataset, self).__init__()
        self.hazy_paths = hazy_paths   # paths of hazy inputs
        self.gt_paths = gt_paths       # paths of matching clear targets
        self.transform = transform     # optional joint transform

    def __len__(self):
        # One sample per hazy image; gt list is assumed the same length.
        return len(self.hazy_paths)

    def __getitem__(self, item):
        # io.imread returns an (H,W,C) uint8 array; scale to [0,1] floats.
        # NOTE(review): assumes color images in HWC layout — grayscale
        # inputs would need an added channel axis first.
        hazy_img = io.imread(self.hazy_paths[item]) / 255.0
        gt_img = io.imread(self.gt_paths[item]) / 255.0
        if self.transform:
            # Bug fix: applying a random transform to each image separately
            # crops DIFFERENT regions and destroys hazy/gt alignment.
            # Stack along the channel axis, transform once, then split back.
            c = hazy_img.shape[-1]
            stacked = self.transform(np.concatenate([hazy_img, gt_img], axis=2))
            # After ToTensor the layout is (2C,H,W): split on dim 0.
            hazy_img, gt_img = stacked[:c], stacked[c:]
        return hazy_img, gt_img
def get_dataset():
    """Build train/val DataLoaders of tensorized, randomly-cropped (hazy, gt) pairs."""
    (train_hazy, train_gt), (val_hazy, val_gt) = read_imgs_paths()
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomCrop(size=(256, 287)),
    ])
    train_dataset = Dehazing_Dataset(train_hazy, train_gt, transform=pipeline)
    val_dataset = Dehazing_Dataset(val_hazy, val_gt, transform=pipeline)
    train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)
    val_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=True, num_workers=0)
    return train_dataloader, val_dataloader
def show_img(sample):
    """Plot a (hazy, gt) pair side by side; expects (C,H,W) tensors."""
    hazy_hwc = sample[0].permute(1, 2, 0)   # (C,H,W) -> (H,W,C) for imshow
    gt_hwc = sample[1].permute(1, 2, 0)
    plt.figure(figsize=(10, 15))
    plt.subplot(1, 2, 1)
    plt.axis('off')
    plt.imshow(hazy_hwc)
    plt.subplot(1, 2, 2)
    plt.axis('off')
    plt.imshow(gt_hwc)
    plt.show()
if __name__ == '__main__':
    train_dataloader, val_dataloader = get_dataset()
    for batch_index, batch in enumerate(train_dataloader):
        # Each batch is a list: [hazy tensors, gt tensors].
        print(type(batch))            # <class 'list'>
        print(batch[0].size())        # torch.Size([2, 3, 256, 287])
        print(batch[1].size())        # torch.Size([2, 3, 256, 287])
注意:
1.Pytorch读取图像数据的几种方式,可参考:链接: https://blog.csdn.net/qq_43665602/article/details/126281393
2.使用torchvision.io和PIL两种方式读取的数据范围为[0,255],并未进行归一化,我们可根据自己的需求对其进行归一化。
- 方式一:transforms.ToTensor()会自行将数据范围归一化为[0,1];
- 方式二:transforms.Normalize(mean,std)可通过调整合适的参数值得到自己想要的归一化结果;