Yolov5改进之更改损失函数(EIOU、SIOU)

2023-11-14

目录

1、修改metrics.py文件

2、修改loss.py函数


1、修改metrics.py文件

找到bbox_iou代码段:

​
def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)

    """在ComputeLoss的__call__函数中调用计算回归损失
       :params box1: 预测框
       :params box2: 预测框
       :return box1和box2的IoU/GIoU/DIoU/CIoU
       """
    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)
        w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)

    # Intersection area tensor.clamp(0):将矩阵中小于0的元数变成0
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
            (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    iou = inter / union
    if CIoU or DIoU or GIoU:
        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # 两个框的最小闭包区域的width convex (smallest enclosing box) width
        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # 两个框的最小闭包区域的height convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist ** 2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
            return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        return iou - (c_area - union) / c_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    return iou  # IoU

​

使用下列代码替换上述代码段:

def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, SIoU=False, EIoU=False, Focal=False, alpha=1, gamma=0.5, eps=1e-7):
    # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)

    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)
        w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)

    # Intersection area
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
            (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    # iou = inter / union # ori iou
    iou = torch.pow(inter/(union + eps), alpha) # alpha iou
    if CIoU or DIoU or GIoU or EIoU or SIoU:
        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
        if CIoU or DIoU or EIoU or SIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = (cw ** 2 + ch ** 2) ** alpha + eps  # convex diagonal squared
            rho2 = (((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4) ** alpha  # center dist ** 2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
                with torch.no_grad():
                    alpha_ciou = v / (v - iou + (1 + eps))
                if Focal:
                    return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha)), torch.pow(inter/(union + eps), gamma)  # Focal_CIoU
                else:
                    return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha))  # CIoU
            elif EIoU:
                rho_w2 = ((b2_x2 - b2_x1) - (b1_x2 - b1_x1)) ** 2
                rho_h2 = ((b2_y2 - b2_y1) - (b1_y2 - b1_y1)) ** 2
                cw2 = torch.pow(cw ** 2 + eps, alpha)
                ch2 = torch.pow(ch ** 2 + eps, alpha)
                if Focal:
                    return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2), torch.pow(inter/(union + eps), gamma) # Focal_EIou
                else:
                    return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2) # EIou
            elif SIoU:
                # SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
                s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + eps
                s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + eps
                sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
                sin_alpha_1 = torch.abs(s_cw) / sigma
                sin_alpha_2 = torch.abs(s_ch) / sigma
                threshold = pow(2, 0.5) / 2
                sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
                angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
                rho_x = (s_cw / cw) ** 2
                rho_y = (s_ch / ch) ** 2
                gamma = angle_cost - 2
                distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
                omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
                omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
                shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
                if Focal:
                    return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha), torch.pow(inter/(union + eps), gamma) # Focal_SIou
                else:
                    return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha) # SIou
            if Focal:
                return iou - rho2 / c2, torch.pow(inter/(union + eps), gamma)  # Focal_DIoU
            else:
                return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        if Focal:
            return iou - torch.pow((c_area - union) / c_area + eps, alpha), torch.pow(inter/(union + eps), gamma)  # Focal_GIoU https://arxiv.org/pdf/1902.09630.pdf
        else:
            return iou - torch.pow((c_area - union) / c_area + eps, alpha)  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    if Focal:
        return iou, torch.pow(inter/(union + eps), gamma)  # Focal_IoU
    else:
        return iou  # IoU

2、修改loss.py函数

找到下面的代码:

iou = bbox_iou(pbox, tbox[i], CIoU=True).squeeze()  # iou(prediction, target)

替换CIOU为EIOU:

iou = bbox_iou(pbox, tbox[i], EIoU=True, alpha=1).squeeze()

注意:以EIOU为例,当 EIOU=True,alphaIOU>1时,则损失函数是两者的结合,为 a-eiou。

           当alphaIOU=1时,则损失函数就是EIOU。

则添加成功!

SIOU同理!!

大家也可以尝试着添加其他损失函数

推荐up主:一个非常棒的B站up主,都去看!!

YOLOV5改进-添加EIOU,SIOU,AlphaIOU._哔哩哔哩_bilibili


加油!每天学会一点点!!

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Yolov5改进之更改损失函数(EIOU、SIOU) 的相关文章

随机推荐

  • 成功解决Windows MemoryError: Unable to allocate 6.38 GiB for an array with shape (38

    因为运行文件所在的磁盘分配内存不够问题造成的 解决方法如下 打开我的电脑 右键属性 高级 性能设置 选择高级 更改 点击E盘 点击自定义大小 设置分配内存 我选择6G 6144kb 点击确定完成 再次运行文件 问题解决
  • invalid credential, access_token is invalid or not latest hint(微信 上传图片返回 error)

    errcode 40001 errmsg invalid credential access token is invalid or not latest hint 3G1y5a0106vr61 这种情况跟这个库没有直接关系 请检查一下是否
  • 5分钟讲解直流线性稳压降压电源基本原理

    怎么把 12 v电变为 5v呢 通过变压器是可以实现的 但是变压器只能转换交流电 那直流电怎么转换呢 我们来看下最简单的降压方式 比如负载是 5欧 那么要得到 5V的压降 按照串联分压原理 需要给它串联一个 7 欧的电阻附加 就能得到 5
  • 【LINUX相关】生成随机数(srand、/dev/random 和 /dev/urandom )

    目录 一 问题背景 二 修改方法 2 1 修改种子 2 2 使用linux中的 dev urandom 生成随机数 三 dev random 和 dev urandom 的原理 3 1 参考连接 3 2 重难点总结 3 2 1 生成随机数的
  • 9*9乘法表

    package practice 99乘法表 public class Test02 public static void main String args for int i 1 i lt 9 i 外层控制行数 for int j 1 j
  • 【查缺补漏】“.“ 和 “->“运算符的区别是什么?

    目录 简介 Note 结语 简介 Hello 非常感谢您阅读海轰的文章 倘若文中有错误的地方 欢迎您指出 昵称 海轰 标签 程序猿 C 选手 学生 简介 因C语言结识编程 随后转入计算机专业 获得过国家奖学金 有幸在竞赛中拿过一些国奖 省奖
  • shell test功能

    test测试功能 对于要测试系统上面某些文件或其相关属性时 可以使用test进行测试 test会根据相关功能返回True或False 测试文件类型test e filename 测试功能 意义 e 该文件是否存在 f 该文件名是否存在且为文
  • SSD1306 - OLED显示屏

    SSD1306 OLED显示屏 芯片介绍 引脚介绍 SSD1306是一款带控制器的用于OLED点阵图形显示系统的单片CMOS OLED PLED驱动器 它由128个SEG 列输出 和64个COM 行输出 组成 该芯片专为共阴极OLED面板设
  • 数据的无量纲化处理和标准化处理的区别是什么

    数据的无量纲化处理和标准化处理的区别是什么 请教 两者除了方法上有所不同外 在其他方面还有什么区别 解答 标准化处理方法是无量纲化处理的一种方法 除此之外 还有相对化处理方法 包括初值比处理 函数化 功效系数 方法 等等 由于标准化处理方法
  • C++11智能指针之unique_ptr

    1 智能指针概念 智能指针是基于RAII机制实现的类 模板 具有指针的行为 重载了operator 与operator gt 操作符 可以 智能 地销毁其所指对象 C 11中有unique ptr shared ptr与weak ptr等智
  • 英语 动词过去式和过去分词的变化规则

    动词过去式和过去分词有规则变化和不规则变化两种 实例顺序 动词原形过去式过去分词 发音 ed在清辅音音素后发音为 t 在浊辅音后发音为 d 在元音后发音也为 d 在 t d 后发音为 id 一 规则变化 1 一般在动词原形后加 ed loo
  • 嵌套查询及其与join的区别

    嵌套即可以写在select子句中 也可以写在from子句中 下面以SQL Entity为例来说明 1 嵌套在select中 以父表为主在select中嵌套子表信息 SELECT c Title ANYELEMENT SELECT oa Fi
  • mysql联合索引最左匹配原则的底层实现原理

    mysql联合索引最左匹配原则的底层实现原理 要看懂 需要熟悉mysql b tree的数据结构 b tree的叶节点和叶子节点的排序特性是按照 从小到大 从左到右的这么一个规则 int直接比大小 uuid比较ASCII码 联合索引的排序规
  • 极简式 Unity 获取 bilibili 直播弹幕、SC、上舰、礼物等 插件

    极简式 Unity 获取 bilibili 直播弹幕 SC 上舰 礼物等 1 声明 下载链接 软件均仅用于学习交流 请勿用于任何商业用途 2 介绍 该项目为Unity实时爬取B站直播弹幕 项目介绍 通过传入B站直播间账号 实现监控B站直播弹
  • java threadlocal 详解_Java中的ThreadLocal深入理解详解

    提到 ThreadLocal是什么 ThreadLocal是一个关于创建线程局部变量的类 通常情况下 我们创建的变量是可以被任何一个线程访问并修改的 而使用ThreadLocal创建的变量只能被当前线程访问 其他线程则无法访问和修改 Glo
  • sha256的python实现

    在 Python 中可以使用 hashlib 库来实现 SHA 256 哈希算法 代码如下 import hashlib defsha256 data sha256 hashlib sha256 sha256 update data enc
  • 为什么printf只能用_cdecl调用约定

    1 什么是调用约定 调用约定 Calling conventions 和type representations 名称修饰 name mangling 同是应用二进制接口 application binary interface ABI 概
  • 分析rocketmq-client产生大量rocketmq_client.log日志文件的原因处理方案

    源码 public static final String CLIENT LOG USESLF4J rocketmq client logUseSlf4j public static final String CLIENT LOG ROOT
  • 05 两层神经网络 - 神经网络和深度学习 [Deep Learning Specialization系列]

    本文是Deep Learning Specialization系列课程的第1课 Neural Networks and Deep Learning 中Shallow Neural Network部分的学习笔记 在前面的章节中 我们以逻辑回归
  • Yolov5改进之更改损失函数(EIOU、SIOU)

    目录 1 修改metrics py文件 2 修改loss py函数 1 修改metrics py文件 找到bbox iou代码段 def bbox iou box1 box2 xywh True GIoU False DIoU False