【项目实战】pytorch载入训练好的模型并进行可视化模型预测绘图

2023-05-16

main函数载入模型,加载图片,输出结果:

if __name__ == '__main__':
  image =  Image.open(r"C:\Users\pic\test\he_5.jpg")
    image =transform(image).unsqueeze(0)
    modelme = torch.load('modefresnet.pkl')
    modelme.eval() #表示将模型转变为evaluation(测试)模式,这样就可以排除BN和Dropout对测试的干扰。
     visualize_model(modelme)
    outputs = modelme(image)
    _, predict = torch.max(outputs.data, 1)
        for j in range(image.size()[0]):
     print('predicted: {}'.format(class_names[predict[j]]))

对图片的统一处理transform:

transform=transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485,0.456,0.406],
                                 std=[0.229,0.224,0.225])
                            ])

对于预测结果进行可视化的函数:

def visualize_model(model, num_images=6):
 was_training = model.training
 model.eval()
 images_so_far = 0
 fig = plt.figure()
 with torch.no_grad():

     #for i, (inputs, labels) in enumerate(dataloaders['val']):
     for i, (inputs, labels) in enumerate(testloder):
       outputs = model(inputs)
       _, preds = torch.max(outputs, 1)
       
       for j in range(inputs.size()[0]):

           images_so_far += 1

           ax = plt.subplot(num_images // 2, 2, images_so_far)

           ax.axis('off')

           ax.set_title('predicted: {}'.format(class_names[preds[j]]))

           imshow(inputs.cpu().data[j])


           if images_so_far == num_images:

            model.train(mode=was_training)
            plt.show()

            return

     model.train(mode=was_training)

载入一新的图片数据集:

data_dir =os.getcwd() + '\\data\\'
dataloadertest =datasets.ImageFolder(os.path.join(data_dir, "tt"),transform=transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485,0.456,0.406],
                                 std=[0.229,0.224,0.225])
                            ]) )
testloder = torch.utils.data.DataLoader(dataloadertest,batch_size = 4,shuffle = True)

目录结构:
在这里插入图片描述
其中要注意传入的图片的预处理:
image = Image.open(r"C:\Users\pic\test\he_5.jpg")
image =transform(image).unsqueeze(0)
需为PIL格式,且需先进行转化才能传入模型。

结果:
在这里插入图片描述

在这里插入图片描述
经测试之后不论是传入单张图片还是一个新数据集结果均符合预期。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【项目实战】pytorch载入训练好的模型并进行可视化模型预测绘图 的相关文章

随机推荐