Tensorflow2.0—DeepLab v3+分割网络原理及代码解析(四)- 训练过程
在Tensorflow2.0—DeepLab v3+分割网络原理及代码解析(三)- 特征提取网络实现中,输入图片已经经过主干网络进行了特征提取,最终得到的fearture map的shape为(512,512,2)。
这篇主要讲讲训练过程吧~~
一、dataset
train_dataloader = DeeplabDataset(train_lines, input_shape, batch_size, num_classes, True, VOCdevkit_path)[3]
val_dataloader = DeeplabDataset(val_lines, input_shape, batch_size, num_classes, False, VOCdevkit_path)
train.py中上述两行代码开启dataset的重定义~
def __getitem__(self, index):
images = [] #保存img
targets = [] #保存label
for i in range(index * self.batch_size, (index + 1) * self.batch_size):
i = i % self.length
name = self.annotation_lines[i].split()[0] #从第i批batch数据的第一张图开始遍历处理
#-------------------------------#
# 从文件中读取图像(返回的都是Image对象,而不是数组)
#-------------------------------#
jpg = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/JPEGImages"), name + ".jpg"))
png = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/SegmentationClass"), name + ".png"))
#-------------------------------#
# 数据增强
#-------------------------------#
jpg, png = self.get_random_data(jpg, png, self.input_shape, random = self.train)
jpg = preprocess_input(np.array(jpg, np.float64)) #输入图片做归一化操作
png = np.array(png)
png[png >= self.num_classes] = self.num_classes
#-------------------------------------------------------#
# 转化成one_hot的形式
# 在这里需要+1是因为voc数据集有些标签具有白边部分
# 我们需要将白边部分进行忽略,+1的目的是方便忽略。
#-------------------------------------------------------#
seg_labels = np.eye(self.num_classes + 1)[png.reshape([-1])]
seg_labels = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1))
images.append(jpg)
targets.append(seg_labels)
images = np.array(images)
targets = np.array(targets)
return images, targets
- images:一个列表,包含了batch_size个图片,每个图片都经过了归一化的操作,shape均为(512,512,3)
- targets:一个列表,包含了batch_size个label,shape均为(img_h,img_w,class_num+1)
二、loss
直接看大佬的blog吧~
憨批的语义分割重制版10——Tensorflow2 搭建自己的DeeplabV3+语义分割平台