YOLO8添加facial landmark和Head Pose的评价逻辑

2023-11-13


前言

这里主要记录基于yolo8pose的改动过程.


一、如何在val.py中添加NME的逻辑

#添加计算NME的函数
def compute_nme(lms_pred, lms_gt, norm):
    lms_pred = lms_pred.reshape((-1, 2))
    lms_gt = lms_gt.reshape((-1, 2))
    # nme = np.mean(np.linalg.norm(lms_pred - lms_gt, axis=1)) / norm 
    nme = np.sum(np.linalg.norm(lms_pred - lms_gt, axis=1)) / (norm * 68)  # from hrnet
    return nme

# 添加打印关键点和人脸框的函数,方便分析误差大的case
def show_results(img, xyxy, conf, landmarks, class_num, color):

    if color > 7:
        color = 7
    colors = [(255,0,0),(0,255,0),(0,0,255),(255,255,0),(255,255,255), (0,255,255), (0,255,255), (0,255,255)]
    # 1----- green color
    # 4----- red color
    
    h,w,c = img.shape
    tl = 1 or round(0.002 * (h + w) / 2) + 1  # line/font thickness
    x1 = int(xyxy[0])
    y1 = int(xyxy[1])
    x2 = int(xyxy[2])
    y2 = int(xyxy[3])
    cv2.rectangle(img, (x1,y1), (x2, y2), colors[1], thickness=tl, lineType=cv2.LINE_AA)  


    # for i in range(5):
    for i in range(67):
        point_x = int(landmarks[3 * i])
        point_y = int(landmarks[3 * i + 1])
        # cv2.circle(img, (point_x, point_y), tl+1, colors[i%5], -1)
        cv2.circle(img, (point_x, point_y), tl+1, colors[color], -1)

    tf = max(tl - 1, 1)  # font thickness
    label = str(conf)[:5]
    cv2.putText(img, label, (x1, y1 - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
    return img

# 根据归一化后的tensor 重新转为cv2画图的格式
def tensor2img(input_tensor):
	# 复制一份
	input_tensor = batch['img'][si].clone().detach()
	# 到cpu
	input_tensor = input_tensor.to(torch.device('cpu'))
	# 去掉批次维度
	input_tensor = input_tensor.squeeze()
	input_tensor = input_tensor.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).type(torch.uint8).numpy()
	orgImg = cv2.cvtColor(input_tensor, cv2.COLOR_RGB2BGR)

                
def update_metrics(self, preds, batch):
    #在该函数中添加计算NME的逻辑
	####################################################################### 
            norm_indices = [36, 45]
            norm = np.linalg.norm(tkpts.reshape(-1, 2)[norm_indices[0]].cpu() - tkpts.reshape(-1, 2)[norm_indices[1]].cpu())

            targetcoord = tkpts.cpu()  #torch.Size([1, 68, 3])

            predncoord = np.zeros((68, 2))  #torch.Size([1, 68, 3])---->(68,2)

            num_std = 10000000
            for i in range(len(pred)):
                predncoord[:,0] = predn[i][6::3].cpu()  
                predncoord[:,1] = predn[i][7::3].cpu() 
                tempNME = compute_nme(torch.from_numpy(predncoord), targetcoord[0,:,0:2], norm)
                if tempNME < num_std:
                    num_std = tempNME

            if num_std > 0.3:
                print(num_std)
                num_std =  0

                # method A
                fileName = batch['im_file'][si]
                orgImg = cv2.imread(fileName)
                
                conf_max = 0
                for n in range(len(pred)):

                    xyxy = predn[n, :4].view(-1).tolist()
                    conf = predn[n, 4].cpu().numpy()
                    if conf > conf_max:
                        conf_max = conf
                        max_index = n 
                    
                landmarks = predn[max_index, 6:].view(-1).tolist()  #204
                # class_num = det[j, 15].cpu().numpy()
                class_num = predn[max_index, 5].cpu().numpy()

                #method 1: to show predicted result with all points.
                orgImg = show_results(orgImg, xyxy, conf, landmarks, class_num, color = 4)

                tbox  = tbox[0].view(-1).tolist()
                tkptLabel = torch.zeros(204)
                tkptLabel[0::3] = targetcoord[:,:, 0]  #  targetcoord torch.Size([1, 68, 3])
                tkptLabel[1::3] = targetcoord[:,:, 1]
                tkpts = tkptLabel.view(-1).tolist()

                # #method 1: to show ground-truth with all points.
                orgImg = show_results(orgImg, tbox, 1, tkpts, class_num, color = 1)

                cv2.imwrite(os.fspath(self.save_dir) + '/'+ fileName.split('/')[-1], orgImg)
                # cv2.imshow(fileName.split('/')[-1], orgImg)
                ##########################################################################

            self.nmes_std.append(num_std)
            ########################################################################

二、在val.py中添加Angle Eorror的逻辑

这里不是直接计算误差,而是先把图片name和对应的欧拉角都保存到txt中,以准备第二阶段分析;因为我们没有把angle的数据放入到yolo中.

1.引入库

#保存路经
save_txt_path = os.path.join(os.fspath(self.save_dir) + 'ypr.txt')

#获取当前pic和对应的landmark
landmarks = predn[max_index, 6:].view(-1).tolist()  #204
# class_num = det[j, 15].cpu().numpy()
class_num = predn[max_index, 5].cpu().numpy()

#放入headpose函数,计算欧拉角
rotation_vector, translation_vector = draw_headpose(orgImg, landmarks)  

#保存为对应的txt格式
with open(save_txt_path, 'a') as f_txt:
    # for  i, v in enumerate(nmes_std):
    rotation_vector = ' '.join(str(e) for e in rotation_vector.flatten())
    translation_vector = ' '.join(str(e) for e in translation_vector.flatten())
    f_txt.write(fileName.split('/')[-1] + " " +str(rotation_vector) + " " +str(translation_vector)  + '\n') 

三、将AFLW2000转为yolo格式

1.参考ultralyticsFaceMark/process300LP2Yolo3D.py

处理单个文件的逻辑

def improve_process_aflw2000(root_folder, folder_name, image_name, label_name, target_size):

    image_path = os.path.join(root_folder, folder_name, image_name)
    label_path = os.path.join(root_folder, folder_name, label_name)

    with open(label_path, 'r') as ff:
        C = sio.loadmat(label_path)
        anno = C['pt3d_68'] 

        anno_x = anno[0]
        anno_y = anno[1]

        bbox_xmin = min(anno_x)
        bbox_ymin = min(anno_y)
        bbox_xmax = max(anno_x)
        bbox_ymax = max(anno_y)

        bbox_width = bbox_xmax - bbox_xmin + 1
        bbox_height = bbox_ymax - bbox_ymin + 1
        
        image = cv2.imread(image_path)
        image_height, image_width, _ = image.shape
        bbox_xcenter = bbox_xmin + bbox_width/2
        bbox_ycenter = bbox_ymin + bbox_height/2


        # #We have to check whether the translation is right, So we have to draw the landmarks on the image.
        # colors = [(255,0,0),(0,255,0),(0,0,255),(255,255,0),(255,255,255), (0,255,255), (0,255,255), (0,255,255)]
        # # 1----- green color
        # # 4----- red color
        
        # h,w,c = image.shape
        # tl = 1 or round(0.002 * (h + w) / 2) + 1  # line/font thickness
        # x1 = int(bbox_xmin)
        # y1 = int(bbox_ymin)
        # x2 = int(bbox_xmax)
        # y2 = int(bbox_ymax)
        # cv2.rectangle(image, (x1,y1), (x2, y2), colors[1], thickness=tl, lineType=cv2.LINE_AA)  

        # # for i in range(5):
        # for i in range(68):
        #     point_x = int(anno_x[i])
        #     point_y = int(anno_y[i])
        #     # cv2.circle(img, (point_x, point_y), tl+1, colors[i%5], -1)
        #     cv2.circle(image, (point_x, point_y), tl+1, colors[0], -1)
            
        # tf = max(tl - 1, 1)  # font thickness
        # label = str(1)
        # cv2.putText(image, label, (x1, y1 - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
        # cv2.imwrite(image_path.split('.')[0].split('/')[-1] + "_68.jpg", image)


        isCrowdAndXYWH = [0, bbox_xcenter/image_width, bbox_ycenter/image_height, bbox_width/image_width, bbox_height/image_height]
        
        anno2 = [[x/image_width, y/image_height, 2] for x,y in np.c_[anno_x,anno_y]]

        return image, isCrowdAndXYWH, anno2

处理文件夹的逻辑

if data_name == 'AFLW2000':
        # folders_train = ['AFW', 'AFW_Flip','HELEN', 'HELEN_Flip','IBUG', 'IBUG_Flip','LFPW', 'LFPW_Flip']
        folders_test = ['AFLW2000']
        annos_test = {}
        for folder_test in folders_test:
            all_files = sorted(os.listdir(os.path.join(root_folder, data_name)))
            image_files = [x for x in all_files if '.jpg'  in x]
            # label_files = [x for x in all_files if '.mat' in x]
            label_files = [x.split('.')[0]+'.mat' for x in all_files if '.mat' in x]
            assert len(image_files) == len(label_files)
            for image_name, label_name in zip(image_files, label_files):
                
                image_crop, isCrowdAndXYWH, anno = improve_process_aflw2000(os.path.join(root_folder), folder_test, image_name, label_name, target_size)               
                image_crop_name = image_name
                # cv2.imwrite(os.path.join(root_folder, 'images', 'test', image_crop_name), image_crop)  #写过一遍,不用再写
                annos_test[image_crop_name] =   isCrowdAndXYWH, anno

        # step 1: 写目录文件
        # with open(os.path.join(root_folder, 'test2yolo.txt'), 'w') as f:
        #     for image_crop_name, anno in annos_test.items():
        #         f.write('./images/test/' + image_crop_name)   #./images/val2017/000000345356.jpg
        #         f.write('\n')

        # step 2: 写单独的标记文件
        for image_crop_name, anno in annos_test.items():
            base_txt = os.path.basename(image_crop_name.split('.')[0]) + ".txt"
            save_txt_path = os.path.join(root_folder,'labels', 'test', base_txt)
            with open(save_txt_path, 'w') as f_txt:
                for xywh in anno[0]:
                    f_txt.write(str(xywh)+' ')
                for x, y, z in anno[1]:
                    f_txt.write(str(x)+' '+str(y)+' '+str(z)+' ')
                f_txt.write('\n') 


总结

进一步工作,在AFLW-2000上验证.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

YOLO8添加facial landmark和Head Pose的评价逻辑 的相关文章

  • 在 Python 中对数据进行求和

    Given that the fitting function is of type 我打算将这样的函数拟合到我拥有的实验数据 x y f x 中 但后来我有一些疑问 当涉及求和时 如何定义拟合函数 一旦定义了函数 即def func re
  • Pandas 将行中的非空值获取到一个单元格中[重复]

    这个问题在这里已经有答案了 给定以下数据框 a pd DataFrame A 1 2 B 4 0 C 1 2 a A B C 0 1 4 1 1 2 0 2 我想创建一个新专栏D包含由列分隔的非空值 每行 像这样 A B C D 0 1 4
  • python中热图的层次聚类

    我有一个 NxM 矩阵 其值范围为 0 到 20 我可以使用 Matplotlib 和 pcolor 轻松获得热图 现在我想使用 scipy 应用层次聚类和树状图 我想重新排序每个维度 行和列 以显示哪些元素相似 根据聚类结果 如果矩阵是方
  • ValueError:“连接”层需要具有匹配形状的输入(连接轴除外)

    我正在尝试为我的项目构建 Pix2Pix 并收到错误 值错误 Concatenate层需要具有匹配形状的输入 除了连接轴之外 获得输入形状 None 64 64 128 None 63 63 128 生成器是一个 U 网模型 我的输入高度
  • 创建一个打开文件并创建字典的函数

    我有一个正在处理的文件 我想创建一个读取文件并将内容放入字典中的函数 然后该字典需要通过 main 函数传递 这是主程序 它无法改变 我所做的一切都必须与主程序配合 def main sunspot dict file str raw in
  • 通过鼻子测试检查某个函数是否发出警告

    我正在使用编写单元测试nose http somethingaboutorange com mrl projects nose 0 11 2 我想检查函数是否引发警告 该函数使用warnings warn 这是很容易就能做到的事情吗 def
  • 将 matplotlib png 转换为 base64 以在 html 模板中查看

    背景 你好 我正在尝试制作一个简单的网络应用程序 按照教程计算阻尼振动方程 并将结果的 png 返回到 html 页面 然后将其转换为 Base64 字符串 Problem 该应用程序运行正常 只是在计算结果时返回损坏的图像图标 可能是因为
  • 使用opencv计算深度视差图

    我无法使用 opencv 从视差图计算深度 我知道两个立体图像中的距离是用以下公式计算的z baseline focal disparity p 但我不知道如何使用地图计算视差 我使用的代码如下 为我提供了两个图像的视差图 import n
  • 在加载“cv2”二进制扩展期间检测到递归

    我有一个小程序 在 pyinstaller 编译后返回 opencv 错误 但无需编译即可工作 我在 Windows 10 上使用 Python 3 8 10 Program 导入 pyautogui将 numpy 导入为 np导入CV2
  • 在 keras 中使用自定义张量流操作

    我在张量流中有一个脚本 其中包含自定义张量流操作 我想将代码移植到 keras 但我不确定如何在 keras 代码中调用自定义操作 我想在 keras 中使用tensorflow 所以到目前为止我发现的教程描述了与我想要的相反的内容 htt
  • 属性错误:类型对象“图像”没有属性“打开”

    Exception in Tkinter callback Traceback most recent call last File C Python34 lib tkinter init py line 1482 in call retu
  • 当 DetailView 遇到时更新模型字段。 [姜戈]

    我有一个类似的 DetailViewviews py views py class CustomView DetailView context object name content model models AppModel templa
  • 向 Python 2.6 添加 SSL 支持

    我尝试使用sslPython 2 6 中的模块 但我被告知它不可用 安装OpenSSL后 我重新编译2 6 但问题仍然存在 有什么建议么 您安装了 OpenSSL 开发库吗 我必须安装openssl devel例如 在 CentOS 上 在
  • 管理文件字段当前 url 不正确

    在 Django 管理中 只要有 FileField 编辑页面上就会有一个 当前 框 其中包含指向当前文件的超链接 但是 此链接会附加到当前页面 url 因此会导致 404 因为不存在这样的页面 例如 http 127 0 0 1 8000
  • 如何从数据框的单元格中获取值?

    我构建了一个条件 从我的数据框中提取一行 d2 df df l ext l ext df item item df wn wn df wd 1 现在我想从特定列中获取一个值 val d2 col name 但结果 我得到一个包含一行和一列
  • 如何在matplotlib中基于x轴更改直方图颜色

    我有根据 pandas 数据框计算出的直方图 我想根据 x 轴值更改颜色 例如 If the value is 0 the color should be green If the value is gt 0 the color shoul
  • Python 垃圾收集有时在 Jupyter Notebook 中不起作用

    我的一些 Jupyter 笔记本经常出现 RAM 不足的情况 而且我似乎无法释放不再需要的内存 这是一个例子 import gc thing Thing result thing do something thing None gc col
  • Flask SQLAlchemy 与 MyPy - 模型类型错误

    我遇到了以下组合问题flask sqlalchemy and mypy 当我定义一个新的 ORM 对象时 例如 class Foo db Model pass where db是使用创建的数据库SQL炼金术应用于flask app mypy
  • 网页抓取 - 如何识别网页上的主要内容

    给定一个新闻文章网页 来自任何主要新闻来源 例如时报或彭博社 我想识别该页面上的主要文章内容 并丢弃其他杂项元素 例如广告 菜单 侧边栏 用户评论 在大多数主要新闻网站上都可以使用的通用方法是什么 有哪些好的数据挖掘工具或库 最好是基于Py
  • python中匹配3个或更多相同的字符

    我正在尝试使用正则表达式在字符串中查找三个或更多相同的字符 例如 你好 不匹配 噢 会的 我尝试过做类似的事情 re compile 1 3 a zA Z re compile w 1 5 但似乎都不起作用 w 1 2 是您正在寻找的正则表

随机推荐

  • JeeSite数据权限控制解决方案

    支持如下数据范围设置 所有数据 所在公司及以下数据 所在公司数据 所在部门及以下数据 所在部门数据 仅本人数据 按明细设置 特殊情况下 跨机构授权 User user UserUtils getUser 使用标准查询 DetachedCri
  • 【python设置临时环境变量】export PYTHONPATH=$(pwd):${PYTHONPATH}

    PYTHONPATH是Python搜索路径 默认我们import的模块都会从PYTHONPATH里面寻找 打印PYTHONPATH import os print sys path gt usr local lib python2 7 di
  • 机器学习之朴素贝叶斯

    朴素贝叶斯 贝叶斯法则 条件独立 如果P X Y Z P X Z P Y Z 或等价地P X Y Z P X Z 则称事件X Y对于给定事件Z是条件独立的 也就是说 当Z发生时 X发生与否与Y发生与否是无关的 朴素贝叶斯 假设每个输入变量独
  • Qt + OpenGL 教程(六):旋转的几种方法(自动旋转、键盘控制、鼠标控制旋转)

    总结了几种旋转的方法 自动旋转 利用计时器 每隔一段时间重新绘制屏幕 实现旋转 键盘控制 点击某个按键 旋转某一角度 鼠标控制 围绕y轴 跟随鼠标旋转 目前只是围绕y轴旋转 不能按任意轴旋转 代码分别为 以后补充
  • ORB_SLAM2 源码解析 特征匹配 (五)

    目录 一 单目初始化中的特征匹配SearchForInitialization 二 跟踪 TrackwithModel TrackReferenceKeyFrame 三 词袋介绍BoW 1 直观理解词袋 2 词袋基本思想 3 从字典结构到k
  • 使用D3.js实现框选节点并进行多节点拖动

    最近再使用d3 js关系图形展示时 需要选中多节点并进行拖动 一开始并不知道D3提供了此API 下面是我结合项目业务整理的框选操作的重点方面的应用 这是d3提供的api 使用鼠标或触摸选择一维或二维区域 可参考示例 https blockb
  • Unity 使用 Dotween 的 Sequence 制作UI动画并且可重复利用

    目录 前言 一 DOTween是什么 二 使用步骤 1 导入DOTween 2 配置DOTween 3 使用代码编写动画 4 代码API解释 总结 前言 DOTween可以制作简易的UI动画 避免创建大量的Animator 本篇文章介绍一下
  • Spring Boot + k8s 最佳实践

    前言 K8s Spring Boot实现零宕机发布 健康检查 滚动更新 优雅停机 弹性伸缩 Prometheus监控 配置分离 镜像复用 配置 健康检查 健康检查类型 就绪探针 readiness 存活探针 liveness 探针类型 ex
  • 书店管理系统

    设计一个书店管理系统 能完成书店的日常管理工作 要求完成的基本功能 1 进货入库记录 2 销售出货记录 3 图书信息查询 可通过书名 作者等途径查询某本图书的详细信息 含书名 作者 出版社 页数 最新入库时间 库存量 价格等 4 自动预警提
  • 时间和日期

    Boost使用的timer和data timerj进行对应和时间日期相关的出来文档 timer包含三个组件 分别为timer progress timer以及对应的progress display timer timer可以测量运行时间 t
  • ROS系统基本功能的使用详解(基本指令/节点/服务/启动文件/动态参数)

    ROS系统基本功能的使用详解 一 创建工作空间 二 创建与编译ROS功能包 三 ROS的基本命令 3 1 节点 3 2 主题 3 3 服务 3 4 参数服务器 四 节点的创建与运行 4 1 创建源文件 4 2 修改CMakeLists tx
  • 域名+七牛云+PicGo+pypora

    域名 七牛云 PicGo pypora 前提准备 域名 自己的域名 七牛云 免费注册申请10G空间够用 picGo 地址 pypora 自行下载 GO 七牛云 注册 gt 登录 gt 控制台 找到对象存储 新建自己空间 绑定域名 添加域名自
  • STM32使用SPI通信驱动2.4G无线射频模块发送数据

    目录 SPI介绍 SPI接口原理 SPI工作原理 SPI特征 引脚配置 结构体 库函数 SPI配置过程 SPI h SPI c NRF24L01无线射频模块 NRF24L01厂家驱动代码移植 NRF24L01 h NRF24L01 c ma
  • 分析一个别人的qt+opengl例子

    Qt5 OpenGL学习笔记 用Qt封装的QOpenGL系列绘制有颜色有深度的三角形 最近学习OpenGL 虽然说Qt可以使用原生OpenGL的API 但是Qt也提供了封装的QOpenGL系列 我用原生的和封装的分别实现了一次简单渲染 都是
  • 竞赛 基于卷积神经网络的乳腺癌分类 深度学习 医学图像

    文章目录 1 前言 2 前言 3 数据集 3 1 良性样本 3 2 病变样本 4 开发环境 5 代码实现 5 1 实现流程 5 2 部分代码实现 5 2 1 导入库 5 2 2 图像加载 5 2 3 标记 5 2 4 分组 5 2 5 构建
  • Python Web系列学习2-Django

    1 django admin Django项目管理工具 建立一个Django项目用 django admin startproject xxx 生成的站点目录结构为 2 进入站点目录 建立一个应用 python manage py star
  • Qt基础篇:Qt读取路径下所有文件或指定类型文件(含递归、判断是否为空、创建路径)

    文件路径的拆解 QFileInfo fileinfo QString file full ui gt m AlgorithmFilePathLineEdit gt text qDebug lt lt file full 输出1 filein
  • Java框架体系架构的知识,分享一点面试小经验

    前言 当前我们都会说SpringBoot是Spring框架对 约定优先于配置理念的最佳实践的产物 一个典型的SpringBoot应用本质上其实就是一个基于Spring框架的应用 而如果大家对Spring框架已经了如指掌 那么 在我们一步步揭
  • Python实现截图——附完整源码

    Python实现截图 附完整源码 为了能在日常工作中方便地截取并保存屏幕截图 我们可以利用Python编写一段代码实现这个功能 本文将介绍基于Windows平台下的Python截图实现方法 包括如何使用Python的Pillow模块以及py
  • YOLO8添加facial landmark和Head Pose的评价逻辑

    目录 TOC 目录 前言 一 如何在val py中添加NME的逻辑 二 在val py中添加Angle Eorror的逻辑 1 引入库 三 将AFLW2000转为yolo格式 1 参考ultralyticsFaceMark process3