图像分割套件PaddleSeg全面解析(八)预测代码解读

2023-10-27

训练完成模型之后,可以对图片进行预测,还可以实现模型结果可视化,查看分割效果。

运行命令如下:

python predict.py \
       --config configs/quick_start/bisenet_optic_disc_512x512_1k.yml \
       --model_path output/iter_1000/model.pdparams \
       --image_path data/optic_disc_seg/JPEGImages/H0003.jpg \
       --save_dir output/result

首先解释一下上面命令的参数含义,

–config指定配置文件,其中包含了模型的名称。

–model_path指定模型路径

–image_path指定输入预测的图片路径

–save_dir指定了输出预测结果保存的路径。

还可以通过以下命令进行多尺度翻转预测。

–aug_pred是否开启增强预测

–scales缩放系数,默认为1.0

–flip_horizontal是否开启水平翻转

–flip_vertical是否开启垂直翻转

多尺度翻转预测是在普通预测的基础上,对输入图片进行多尺度缩放、水平垂直方向翻转等操作,得出多个预测结果,然后将多个预测结果相加作为最后的输出结果。可以通过下图了解一下预测程序的工作流程。

下面我们解读一下predict.py的代码。

if __name__ == '__main__':
	#解析传入参数
    args = parse_args()
    #执行主体函数
    main(args)

我们通过解读parse_args函数来了解一下predict.py脚本支持的输入参数与val.py基本一致。

def parse_args():
    parser = argparse.ArgumentParser(description='Model prediction')

    # params of prediction
    # 配置文件路径
    parser.add_argument(
        "--config", dest="cfg", help="The config file.", default=None, type=str)
    # 训练好的模型权重路径
    parser.add_argument(
        '--model_path',
        dest='model_path',
        help='The path of model for prediction',
        type=str,
        default=None)
    # 输入的预测图片路径
    parser.add_argument(
        '--image_path',
        dest='image_path',
        help=
        'The path of image, it can be a file or a directory including images',
        type=str,
        default=None)
    #输出的保存预测结果路径
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the predicted results',
        type=str,
        default='./output/result')

    # augment for prediction
    #是否使用多尺度和翻转增强的方式预测。这种方法会带来精度的提升,推荐使用
    parser.add_argument(
        '--aug_pred',
        dest='aug_pred',
        help='Whether to use mulit-scales and flip augment for prediction',
        action='store_true')
    # 指定缩放系数,1.0为保持尺寸不变,可以指定多个系数,用空格隔开。
    parser.add_argument(
        '--scales',
        dest='scales',
        nargs='+',
        help='Scales for augment',
        type=float,
        default=1.0)
    # 开启图片水平翻转
    parser.add_argument(
        '--flip_horizontal',
        dest='flip_horizontal',
        help='Whether to use flip horizontally augment',
        action='store_true')
    #开启图片垂直翻转
    parser.add_argument(
        '--flip_vertical',
        dest='flip_vertical',
        help='Whether to use flip vertically augment',
        action='store_true')

    # sliding window prediction
    #滑动窗口参数配置,是否开启滑动窗口
    parser.add_argument(
        '--is_slide',
        dest='is_slide',
        help='Whether to prediction by sliding window',
        action='store_true')
    # 滑动窗口尺寸
    parser.add_argument(
        '--crop_size',
        dest='crop_size',
        nargs=2,
        help=
        'The crop size of sliding window, the first is width and the second is height.',
        type=int,
        default=None)
    # 滑动窗口移动的步长,需要指定水平方向和垂直方向两个参数。
    parser.add_argument(
        '--stride',
        dest='stride',
        nargs=2,
        help=
        'The stride of sliding window, the first is width and the second is height.',
        type=int,
        default=None)

    return parser.parse_args()

以上是输入参数的解析。在main函数中,主要使用core/predict.py模块中的predict函数对图片进行预测。

首先看一下predict函数的代码概要。

然后对predict函数进行代码解读。

def predict(model,
            model_path,
            transforms,
            image_list,
            image_dir=None,
            save_dir='output',
            aug_pred=False,
            scales=1.0,
            flip_horizontal=True,
            flip_vertical=False,
            is_slide=False,
            stride=None,
            crop_size=None):
    #加载模型权重
    para_state_dict = paddle.load(model_path)
    model.set_dict(para_state_dict)
    #设置模型为评估模式
    model.eval()

    added_saved_dir = os.path.join(save_dir, 'added_prediction')
    pred_saved_dir = os.path.join(save_dir, 'pseudo_color_prediction')

    logger.info("Start to predict...")
    #设置进度条
    progbar_pred = progbar.Progbar(target=len(image_list), verbose=1)
    #遍历图片列表
    for i, im_path in enumerate(image_list):
    	#读取图像
        im = cv2.imread(im_path)
        #获取图像宽高
        ori_shape = im.shape[:2]
        #对图像进行转换
        im, _ = transforms(im)
        #新增一个维度
        im = im[np.newaxis, ...]
        #将ndarray数据转换为张量
        im = paddle.to_tensor(im)
		#是否开启多尺度翻转预测
        if aug_pred:
            #开启多尺度翻转预测,则对图片进行多尺度翻转预测
            pred = infer.aug_inference(
                model,
                im,
                ori_shape=ori_shape,
                transforms=transforms.transforms,
                scales=scales,
                flip_horizontal=flip_horizontal,
                flip_vertical=flip_vertical,
                is_slide=is_slide,
                stride=stride,
                crop_size=crop_size)
        else:
            #如果没有开启多尺度翻转预测,则对图片进行常规的推理预测操作。
            pred = infer.inference(
                model,
                im,
                ori_shape=ori_shape,
                transforms=transforms.transforms,
                is_slide=is_slide,
                stride=stride,
                crop_size=crop_size)
        #将返回数据去除多余的通道,并转为uint8类型,方便保存为图片
        pred = paddle.squeeze(pred)
        pred = pred.numpy().astype('uint8')
		
        #获取保存图片的名称
        # get the saved name
        if image_dir is not None:
            im_file = im_path.replace(image_dir, '')
        else:
            im_file = os.path.basename(im_path)
        if im_file[0] == '/':
            im_file = im_file[1:]
		#保存结果
        added_image = utils.visualize.visualize(im_path, pred, weight=0.6)
        added_image_path = os.path.join(added_saved_dir, im_file)
        mkdir(added_image_path)
        cv2.imwrite(added_image_path, added_image)

		# 保存伪色彩预测结果
        # save pseudo color prediction
        pred_mask = utils.visualize.get_pseudo_color_map(pred)
        pred_saved_path = os.path.join(pred_saved_dir,
                                       im_file.rsplit(".")[0] + ".png")
        mkdir(pred_saved_path)
        pred_mask.save(pred_saved_path)

        # pred_im = utils.visualize(im_path, pred, weight=0.0)
        # pred_saved_path = os.path.join(pred_saved_dir, im_file)
        # mkdir(pred_saved_path)
        # cv2.imwrite(pred_saved_path, pred_im)
		#进度条进度加1
        progbar_pred.update(i + 1)

在上述代码中,根据输入参数不同,则调用不同的推理函数,已经在上一节评估代码解读中已经介绍,这里不再重复。

以上就是PaddleSeg当前版本主要代码的全部解读。

本系列文章也会定期与PaddleSeg版本更新保持同步。因本人水平有限,若有错误之处还请谅解。

PaddleSeg仓库地址:https://github.com/PaddlePaddle/PaddleSeg

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

图像分割套件PaddleSeg全面解析(八)预测代码解读 的相关文章

随机推荐