import tensorflow as tf
import cv2
from PIL import Image
import numpy as np
import colorsys
import os
import matplotlib.pyplot as plt
def resize_image(image, size):
    """Letterbox-resize a PIL image into ``size`` while keeping its aspect ratio.

    The image is scaled by the largest factor that still fits inside ``size``
    and pasted, centred, on a grey (128, 128, 128) RGB canvas of exactly
    ``size``.

    Args:
        image: PIL image to resize.
        size: target ``(width, height)`` of the returned canvas.

    Returns:
        tuple: ``(canvas, nw, nh)`` where ``canvas`` is the padded RGB image
        and ``(nw, nh)`` is the size of the scaled picture inside it.
    """
    target_w, target_h = size
    src_w, src_h = image.size
    ratio = min(target_w / src_w, target_h / src_h)
    nw, nh = int(src_w * ratio), int(src_h * ratio)

    scaled = image.resize((nw, nh), Image.BICUBIC)
    canvas = Image.new('RGB', size, (128, 128, 128))
    # Centre the scaled picture; the remaining border stays grey.
    offset = ((target_w - nw) // 2, (target_h - nh) // 2)
    canvas.paste(scaled, offset)
    return canvas, nw, nh
def preprocess_input(image):
    """Map pixel values from the range [0, 255] to [-1, 1]."""
    halved = image / 127.5
    return halved - 1
# Network input size as (height, width); must match the size used in training.
input_shape = 512, 512
# Number of classes the model predicts: target classes + 1
# (the extra class is presumably background — confirm against training setup).
num_classes = 2
def preProcessing(filepath):
    """Load an image from disk and turn it into a model-ready batch.

    Args:
        filepath: path of the image file to read.

    Returns:
        tuple: ``(old_img, (h, w), (nw, nh), image_data)`` where

        * ``old_img`` is the original image as an RGB PIL image,
        * ``(h, w)`` is its height and width,
        * ``(nw, nh)`` is the size of the scaled picture inside the
          letterboxed network input (see ``resize_image``),
        * ``image_data`` is a float32 array of shape
          ``(1, input_shape[0], input_shape[1], 3)`` scaled to [-1, 1].

    Raises:
        FileNotFoundError / PIL.UnidentifiedImageError: if the file is
        missing or is not a readable image.
    """
    # Read the file once with PIL and close it straight away. convert() loads
    # the pixels into a new in-memory image, so the file handle is released.
    # Forcing RGB keeps Image.blend in postProcessing working for RGBA PNGs
    # and greyscale images, because the mask is always RGB.
    with Image.open(filepath) as raw_img:
        old_img = raw_img.convert('RGB')
    # Take the size from the same image that is fed to the network and later
    # blended. The previous cv2.imread applies EXIF orientation by default,
    # so it could report swapped h/w for rotated JPEGs.
    w, h = old_img.size

    # Letterbox to the network input size; resize_image takes (width, height).
    image_data, nw, nh = resize_image(old_img, (input_shape[1], input_shape[0]))
    # Scale to [-1, 1] and add the leading batch dimension.
    image_data = np.expand_dims(preprocess_input(np.array(image_data, np.float32)), 0)
    return old_img, (h, w), (nw, nh), image_data
def postProcessing(pr_arrays=None, new_size=None, old_img=None):
    """Turn raw per-pixel class scores into a colour mask blended over the image.

    When called with no arguments (as the ``__main__`` loop does), the inputs
    are read from the module-level globals ``pr_arrays``, ``nw``, ``nh`` and
    ``old_img`` that the loop assigns. Each input can also be passed explicitly.

    Args:
        pr_arrays: class scores for the letterboxed network input, indexed
            ``(row, col, class)``. Its spatial size is assumed to equal
            ``input_shape`` — TODO confirm against the exported model.
        new_size: ``(nw, nh)`` size of the scaled picture inside the letterbox,
            as returned by ``preProcessing``.
        old_img: original PIL image the mask is blended with.

    Returns:
        PIL.Image: 50/50 blend of the original image and the colour mask.
    """
    # Backward compatibility: the original signature took no arguments.
    g = globals()
    if pr_arrays is None:
        pr_arrays = g['pr_arrays']
    if new_size is None:
        new_size = (g['nw'], g['nh'])
    if old_img is None:
        old_img = g['old_img']
    nw, nh = new_size
    # Image.blend needs identical mode and size. The mask is RGB and is sized
    # from old_img itself, so convert the original to RGB as well.
    old_img = old_img.convert('RGB')
    w, h = old_img.size

    # resize_image scaled the picture to (nw, nh) and pasted it centred on a
    # grey canvas. Crop that region out before resizing. Otherwise the padding
    # is stretched into the mask and the mask no longer lines up with the image.
    top = (input_shape[0] - nh) // 2
    left = (input_shape[1] - nw) // 2
    pr = np.ascontiguousarray(pr_arrays[top:top + nh, left:left + nw])
    # Resize back to the original image size (cv2 takes (width, height)).
    pr = cv2.resize(pr, (w, h), interpolation=cv2.INTER_LINEAR)
    pr = pr.argmax(axis=-1)  # class index of every pixel

    if num_classes <= 21:
        colors = [(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128),
                  (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128),
                  (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128),
                  (128, 64, 128)]  # last entry was mistyped as (128, 64, 12)
    else:
        # Evenly spaced, fully saturated hues, one per class.
        hsv_tuples = [(x / num_classes, 1., 1.) for x in range(num_classes)]
        colors = [tuple(int(c * 255) for c in colorsys.hsv_to_rgb(*hsv)) for hsv in hsv_tuples]

    # Lookup table label -> RGB. Labels outside [0, num_classes) stay black,
    # as they did with the previous per-class loop.
    palette = np.zeros((max(num_classes, int(pr.max()) + 1), 3), dtype=np.uint8)
    palette[:num_classes] = colors[:num_classes]
    seg_img = palette[pr]

    resultImage = Image.fromarray(seg_img)
    image = Image.blend(old_img, resultImage, 0.5)
    return image
def saveAndShow(image, filepath=None):
    """Save the result image under ``servingOut/`` and display it with matplotlib.

    Args:
        image: PIL image to save and show.
        filepath: path of the source image. Its base name (without extension)
            plus ``"httpResult.jpg"`` becomes the output file name, and its
            base name becomes the plot title. Defaults to the module-level
            ``filepath`` global assigned by the ``__main__`` loop, so the
            original one-argument call keeps working.
    """
    if filepath is None:
        filepath = globals()['filepath']
    basename = os.path.basename(filepath)
    # splitext drops the extension whatever its length. The old [:-4] slice
    # was only right for 3-letter extensions such as .jpg/.png.
    savename = os.path.splitext(basename)[0] + "httpResult.jpg"
    savePath = 'servingOut/'
    # Creates missing parents too, and does not fail if the directory appears
    # between the existence check and the creation.
    os.makedirs(savePath, exist_ok=True)
    image.save(os.path.join(savePath, savename))
    plt.title(basename)
    plt.imshow(image)
    plt.show()
if __name__ == '__main__':
    # Load the exported SavedModel once; it is reused for every image.
    mymodel = tf.saved_model.load("test/1")
    while True:
        try:
            filepath = input('请输入待预测图像路径(输入c退出): ')
        except EOFError:
            # stdin was closed (e.g. piped input ran out). EOFError is an
            # Exception, so without this branch the generic handler below
            # would swallow it and the loop would spin forever.
            break
        # Tolerate stray whitespace around a pasted path.
        filepath = filepath.strip()
        if filepath == 'c':
            break
        try:
            # These names stay module globals on purpose: postProcessing() and
            # saveAndShow() read them when called without arguments.
            old_img, (h, w), (nw, nh), image_data = preProcessing(filepath=filepath)
            pr = mymodel(image_data)[0]
            # The model returns a TensorFlow tensor; post-processing needs NumPy.
            pr_arrays = pr.numpy()
            image = postProcessing()
            saveAndShow(image)
        except Exception as e:
            # Interactive top-level boundary: report the failure and keep
            # prompting for the next image.
            print(e)
本脚本是在 httpClient.py(见参考文章)的基础上修改的,主要改动是直接加载导出的模型,并把预处理后的数据输入模型进行推理:
# Excerpt from the __main__ loop: load the exported SavedModel from "test/1".
mymodel = tf.saved_model.load("test/1")
# Run inference on the preprocessed batch. image_data holds a single image
# (preProcessing adds the batch axis with expand_dims), so [0] presumably
# selects that image's prediction — confirm against the model's output signature.
pr = mymodel(image_data)[0]
此时得到的 pr 是 TensorFlow 的 Tensor 类型,需要先调用 .numpy() 转换成 NumPy 数组,才能交给 cv2/NumPy 进行后处理:
# Convert the TensorFlow tensor to a NumPy array so cv2.resize / argmax in post-processing can use it.
pr_arrays = pr.numpy()