拉取modelscope镜像
sudo docker pull registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.3.0-py37-torch1.11.0-tf1.15.5-1.1.2
拉取增强代码
git clone https://gitee.com/binghai228/Bringing-Old-Photos-Back-to-Life.git
cd Bringing-Old-Photos-Back-to-Life
安装增强模块:
cd Face_Enhancement/models/networks/
git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm .
cd ../../../
cd Global/detection_models
git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm .
cd ../../
cd Face_Detection/
wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
bzip2 -d shape_predictor_68_face_landmarks.dat.bz2
cd ../
cd Face_Enhancement/
wget https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life/releases/download/v1.0/face_checkpoints.zip
unzip face_checkpoints.zip
cd ../
cd Global/
wget https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life/releases/download/v1.0/global_checkpoints.zip
unzip global_checkpoints.zip
cd ../
创建容器:
sudo nvidia-docker run -dit --net=host --name beauty --shm-size="1g" -v $PWD:/beauty registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.3.0-py37-torch1.11.0-tf1.15.5-1.1.2 bash
进入容器:
sudo docker exec -it beauty bash
cd beauty
安装依赖:
pip install -r requirements.txt
测试修复:
python run.py --input_folder test \
--output_folder result \
--GPU 0 \
--with_scratch \
--HR
从官网下载redis稳定版(本文以7.0.7为例)到项目根目录下,例如:
wget https://download.redis.io/releases/redis-7.0.7.tar.gz
tar -zxvf redis-7.0.7.tar.gz
在docker容器中编译redis:
cd redis-7.0.7
make
编译完成后安装redis:
cd src
make install
启动redis:
redis-server &
测试连接:
redis-cli
127.0.0.1:6379> ping
正常输出如下:
PONG
查看所有键(KEYS会一次性遍历整个数据库,仅建议在调试时使用):
keys *
删除某个键:
del key1
断开连接:
exit
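其他相关命令如下(注意:redis-server 没有 stop/start/restart 子命令,它会把这类参数当作配置文件路径而启动失败):
redis-cli shutdown                                    # 停止redis
redis-server --daemonize yes                          # 以后台方式启动redis
redis-cli shutdown; redis-server --daemonize yes      # 重启redis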
安装数据库驱动:
pip install redis
安装fastapi:
pip install fastapi -i https://mirror.baidu.com/pypi/simple
pip install "uvicorn[standard]" -i https://mirror.baidu.com/pypi/simple
pip install python-multipart -i https://mirror.baidu.com/pypi/simple
后台代码(保存为main_fastapi.py,与下方uvicorn启动命令中的模块名对应,并与tool.py放在同一目录):
import numpy as np
import cv2
import base64
from fastapi import FastAPI, File, UploadFile, Depends
import time
import uuid
import os
from fastapi import Request
import json
from redis import Redis,ConnectionPool
from tool import base64_encode_image,base64_decode_image
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
# Origins allowed to call this API from a browser; "*" admits every site.
origins = ["*"]

# ASGI application object; uvicorn loads it as ``main_fastapi:app``.
app = FastAPI()

# NOTE(review): a wildcard origin together with allow_credentials=True is a
# permissive combination (Starlette reflects the caller's Origin for
# credentialed requests). No endpoint here reads cookies or auth headers, so
# allow_credentials=False is probably sufficient -- confirm with the frontend.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# One connection pool shared by every request. Creating a pool per request
# (and handing it to Redis(connection_pool=...)) means Redis.close() never
# disconnects that pool, so its sockets are only reclaimed by GC.
_RDB_POOL = ConnectionPool(host='127.0.0.1', port=6379)


def get_rdb():
    """FastAPI dependency that yields a Redis client for the current request.

    Yields:
        Redis: a client drawing connections from the shared ``_RDB_POOL``;
        it is closed (its connection returned to the pool) after the request.
    """
    rdb = Redis(connection_pool=_RDB_POOL)
    try:
        yield rdb
    finally:
        rdb.close()
@app.post("/beauty/")
async def get_image(request: Request, file: UploadFile = File(...), db: Redis = Depends(get_rdb)):
    """Queue an uploaded photo for the retouching worker and wait for the result.

    The upload is decoded with OpenCV and stored in Redis under a fresh UUID
    key as JSON: ``in_img`` (base64 of the float32 pixels) plus ``height``,
    ``width`` and ``channel``. The key expires after 50 s. It is then polled
    once per second until a separate worker process adds ``out_img`` to the
    same JSON document.

    Returns:
        dict: ``{"state": 1, "img": <base64-encoded JPEG>}`` on success,
        otherwise ``{"state": -1, "code": <message>}``.
    """
    # asyncio.sleep yields to the event loop; time.sleep inside an async
    # handler would freeze every other request served by this process.
    import asyncio

    print("接收到数据")
    raw = await file.read()
    # cv2.imdecode returns None for undecodable bytes; skip it for empty input.
    image = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR) if raw else None
    if image is None:
        print('图片解析失败')
        return {"state": -1, "code": '图片解析失败'}

    h, w, c = image.shape
    k = str(uuid.uuid4())
    in_data = {"in_img": base64_encode_image(image.astype('float32')), "height": h, "width": w, "channel": c}
    # Value and 50 s TTL in a single command, so the key is never left without
    # an expiry if the process dies in between.
    db.set(k, json.dumps(in_data), ex=50)

    busy = {"state": -1, "code": '当前拥挤,请稍候再试'}
    time_out_num = 0
    while True:
        # NOTE: the Redis client is synchronous; each call briefly blocks the loop.
        stored = db.get(k)
        if stored is None:
            # The key expired (50 s TTL) or was removed before a result arrived.
            print('超时')
            return busy
        output = json.loads(stored.decode("utf-8"))
        if output.get('out_img') is not None:
            im = base64_decode_image(output["out_img"], "float32", (output.get('height'), output.get('width'), output.get('channel')))
            db.delete(k)
            _, buffer_img = cv2.imencode('.jpg', im)
            img64 = base64.b64encode(buffer_img).decode('utf-8')
            print("完成")
            return {"state": 1, "img": img64}
        await asyncio.sleep(1)
        time_out_num += 1
        if time_out_num > 50:
            db.delete(k)
            print('超时')
            return busy
# Jinja2 templates (e.g. index.html) are loaded from the ./templates folder.
templates = Jinja2Templates(directory="templates")
# Files under ./static are served at the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render the landing page from ``templates/index.html``."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
启动fastapi:
uvicorn main_fastapi:app --host 0.0.0.0 --port 8041 &
其中tool.py文件函数如下:
import numpy as np
import base64
def base64_encode_image(a):
    """Base64-encode the raw element bytes of a numpy array.

    Only the data is encoded, in C (row-major) order; dtype and shape are not
    stored, so the decoder must be given both explicitly.

    Args:
        a: numpy array to encode.

    Returns:
        str: base64 text.
    """
    # b64encode needs a C-contiguous buffer; sliced or transposed arrays would
    # otherwise raise "ndarray is not C-contiguous".
    return base64.b64encode(np.ascontiguousarray(a)).decode("utf-8")
def base64_decode_image(a, dtype, shape):
    """Decode base64 text back into a numpy array of the given dtype and shape.

    Args:
        a: base64 data as ``str`` (or ``bytes``).
        dtype: element dtype the bytes were encoded with, e.g. ``"float32"``.
        shape: target shape, e.g. ``(height, width, channel)``.

    Returns:
        numpy.ndarray: a read-only array viewing the decoded bytes.
    """
    if isinstance(a, str):
        a = a.encode("utf-8")
    # base64.decodestring was removed in Python 3.9; b64decode (validate=False)
    # decodes the same way and also discards non-alphabet characters.
    raw = base64.b64decode(a)
    return np.frombuffer(raw, dtype=dtype).reshape(shape)
模型循环推理代码(需在容器内、与tool.py同一目录下,作为独立于fastapi服务的另一个进程运行;代码会读写./in和./out目录,请确保这两个目录存在):
import os
import numpy as np
import json
import cv2
import base64
import time
import redis
from tool import base64_encode_image,base64_decode_image
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
def _find_pending_job(db):
    """Return ``(key, job)`` for the first stored job that has no ``out_img`` yet.

    Keys are walked incrementally with SCAN rather than KEYS, which traverses
    the whole keyspace in one blocking call. Keys that disappear between SCAN
    and GET, and values that are not JSON objects, are skipped.

    Returns:
        tuple: ``(key, job_dict)``, or ``(None, None)`` when nothing is pending.
    """
    for k in db.scan_iter():
        raw = db.get(k)
        if raw is None:
            continue
        try:
            job = json.loads(raw.decode("utf-8"))
        except ValueError:
            # Not a JSON job written by the API; ignore instead of crashing.
            continue
        if isinstance(job, dict) and job.get('out_img') is None:
            return k, job
    return None, None


def _retouch(skin_retouching, in_img):
    """Run the skin-retouching pipeline on ``in_img`` and return the output image.

    The pipeline is fed through files: the input is written to ./in/test.jpg,
    the pipeline's ``OutputKeys.OUTPUT_IMG`` is written to ./out/result.png and
    read back with ``cv2.IMREAD_COLOR``.

    Raises:
        RuntimeError: if the result file cannot be read back.
    """
    cv2.imwrite('./in/test.jpg', in_img)
    result = skin_retouching('./in/test.jpg')
    cv2.imwrite('./out/result.png', result[OutputKeys.OUTPUT_IMG])
    out_img = cv2.imread('./out/result.png', cv2.IMREAD_COLOR)
    if out_img is None:
        raise RuntimeError('failed to read ./out/result.png')
    return out_img


if __name__ == "__main__":
    db = redis.StrictRedis(host="localhost", port=6379, db=0)
    skin_retouching = pipeline(Tasks.skin_retouching, model='damo/cv_unet_skin-retouching')
    # cv2.imwrite returns False (instead of raising) when the folder is missing.
    os.makedirs('./in', exist_ok=True)
    os.makedirs('./out', exist_ok=True)
    print('美颜模型开始运行')
    while True:
        k, output = _find_pending_job(db)
        if k is None:
            time.sleep(0.2)
            continue
        try:
            in_img = base64_decode_image(output["in_img"], "float32", (output.get('height'), output.get('width'), output.get('channel')))
            out_img = _retouch(skin_retouching, in_img)
        except Exception as exc:  # top-level boundary: one bad job must not kill the worker
            print('推理失败:', repr(exc))
            # Drop the job so it is not picked up again on every poll.
            db.delete(k)
            continue
        h, w, c = out_img.shape
        # Store the result's own shape so the API reshapes out_img correctly.
        output.update({"out_img": base64_encode_image(out_img.astype('float32')), "height": h, "width": w, "channel": c})
        # xx=True: only write if the API has not already deleted the key after
        # timing out (otherwise it would be re-created and never cleaned up).
        # keepttl=True: a plain SET would discard the expiry the API attached.
        db.set(k, json.dumps(output), xx=True, keepttl=True)
        print("完成推理")
        time.sleep(0.2)
本地测试脚本client_beauty.py:
import numpy as np
import requests
import cv2
import base64
def main(url="http://127.0.0.1:8041/beauty/", image_path='1.jpg', out_path='result.png'):
    """Post ``image_path`` to the retouching service and save the returned image.

    Prints '完成' after writing ``out_path``. Prints '调用失败' when the HTTP
    status is not OK, the service reports ``state`` != 1 (e.g. its timeout
    reply, which has no ``img`` field), or the returned bytes do not decode.
    """
    # The route is declared as "/beauty/"; using the exact path avoids relying
    # on the framework's trailing-slash redirect.
    with open(image_path, 'rb') as fp:  # closes the upload handle afterwards
        response = requests.post(url=url, files={'file': fp})
    img = None
    if response.ok:
        result = response.json()
        if result.get("state") == 1:
            data = base64.b64decode(result["img"])
            img = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
    if img is None:
        print('调用失败')
    else:
        cv2.imwrite(out_path, img)
        print('完成')


if __name__ == "__main__":
    main()
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)