PyTorch 的自定义数据集接口真的很方便, 记录一下自己用于读取分割数据的脚本:
# -*- coding: utf-8 -*-
# @Time : 2019/10/31 21:36
# @Author : Yunyun Xu
# @Contact : 1443563995@qq.com
# @File : MyDatasetReader.py
# @Software: Pycharm
# @Blog : https://me.csdn.net/xuyunyunaixuexi
import os
import numpy as np
import scipy.misc as m
from PIL import Image
from torch.utils import data
from mypath import Path
from torchvision import transforms
import custom_transforms as tr
class MyEggSegmentation(data.Dataset):
    """Segmentation dataset that pairs ``.png`` images with ``.png`` label maps.

    Images are collected recursively from ``<root>/leftImg8bit/<split>`` and
    labels from ``<root>/gtFine_trainvaltest/gtFine/<split>``. The i-th image
    is paired with the i-th label after both path lists have been sorted.

    Args:
        args: Options object; ``args.base_size`` and ``args.crop_size`` are
            read when the transform pipelines are built.
        root: Dataset root directory.
        split: Sub-directory to load; ``__getitem__`` supports ``"train"``,
            ``"val"`` and ``"test"``.

    Raises:
        Exception: If no ``.png`` image is found for the requested split.
    """

    # NUM_CLASSES = 19

    def __init__(self, args, root=Path.db_root_dir("MyEggs"), split="train"):
        self.root = root
        self.split = split
        self.args = args
        self.image_files = {}
        self.label_files = {}

        self.images_base = os.path.join(self.root, 'leftImg8bit', self.split)
        self.annotations_base = os.path.join(
            self.root, 'gtFine_trainvaltest', 'gtFine', self.split)

        # os.walk yields names in arbitrary order. Because images and labels
        # are matched purely by list position, both lists are sorted so that
        # mirrored directory trees line up deterministically.
        # NOTE(review): this assumes image and label files sort in the same
        # relative order (e.g. identical relative paths) -- confirm against
        # the actual dataset layout.
        self.image_files[split] = sorted(
            self.recursive_glob(rootdir=self.images_base, suffix=".png"))
        self.label_files[split] = sorted(
            self.recursive_glob(rootdir=self.annotations_base, suffix=".png"))

        if not self.image_files[split]:
            raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))

        # The original read ``self.files``, which is never defined and raised
        # AttributeError on every successful construction.
        print("Found %d %s images" % (len(self.image_files[split]), split))

    def __len__(self):
        """Return the number of images found for the active split."""
        return len(self.image_files[self.split])

    def __getitem__(self, index):
        """Load the image/label pair at ``index`` and apply the split's transforms.

        Raises:
            ValueError: If ``self.split`` is not ``"train"``, ``"val"`` or
                ``"test"`` (previously ``None`` was returned silently).
        """
        img_path = self.image_files[self.split][index].rstrip()
        lbl_path = self.label_files[self.split][index].rstrip()

        # Force three-channel RGB (e.g. drops the alpha channel of RGBA PNGs).
        _img = Image.open(img_path).convert("RGB")
        # Read the label as an index map, without any mode conversion.
        _target = Image.open(lbl_path)

        # NOTE(review): the transforms in ``custom_transforms`` must read these
        # exact keys ("images" / "label"); that module is not visible here.
        sample = {"images": _img, "label": _target}

        # Each split gets its own pipeline; previously every split was routed
        # through the random training augmentations.
        if self.split == "train":
            return self.transform_tr(sample)
        if self.split == "val":
            return self.transform_val(sample)
        if self.split == "test":
            return self.transform_ts(sample)
        raise ValueError("Unsupported split: %r" % (self.split,))

    def recursive_glob(self, rootdir='.', suffix=" "):
        """Return the paths of all files under ``rootdir`` ending with ``suffix``.

        The search is recursive and the result order follows ``os.walk``.
        """
        return [os.path.join(looproot, filename)
                for looproot, _, filenames in os.walk(rootdir)
                for filename in filenames if filename.endswith(suffix)]

    def transform_tr(self, sample):
        """Apply the training pipeline (flip, scale-crop, blur, normalize, tensor)."""
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size,
                               crop_size=self.args.crop_size, fill=255),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])
        return composed_transforms(sample)

    def transform_val(self, sample):
        """Apply the validation pipeline (fixed scale-crop, normalize, tensor)."""
        composed_transforms = transforms.Compose([
            tr.FixScaleCrop(crop_size=self.args.crop_size),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])
        return composed_transforms(sample)

    def transform_ts(self, sample):
        """Apply the test pipeline (fixed resize, normalize, tensor)."""
        composed_transforms = transforms.Compose([
            tr.FixedResize(size=self.args.crop_size),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])
        return composed_transforms(sample)
测试了一下,是可以遍历的,证明自定义数据集接口(继承data.Dataset)是正确的:
但是本人还有一个问题: 分割网络应该如何根据自己数据集的图像大小来确定 crop_size? 还是只能不断地尝试、观察效果? 希望得到大家的解答。