在 PyTorch 中原生测量多类分类的 F1 分数

2024-04-12

我正在尝试在 PyTorch 中本地实现宏 F1 分数(F-measure),而不是使用已经广泛使用的sklearn.metrics.f1_score https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html以便直接在 GPU 上计算测量值。

据我了解,为了计算宏 F1 分数,我需要计算所有标签的灵敏度和精度的 F1 分数,然后取所有这些的平均值。

我的尝试

我当前的实现如下所示:

def confusion_matrix(y_pred: torch.Tensor, y_true: torch.Tensor, n_classes: int):
    conf_matrix = torch.zeros([n_classes, n_classes], dtype=torch.int)
    y_pred = torch.argmax(y_pred, 1)
    for t, p in zip(y_true.view(-1), y_pred.view(-1)):
        conf_matrix[t.long(), p.long()] += 1
    return conf_matrix

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    conf_matrix = confusion_matrix(y_pred, y_true, self.classes)
    TP = conf_matrix.diag()
    f1_scores = torch.zeros(self.classes, dtype=torch.float)
    for c in range(self.classes):
        idx = torch.ones(self.classes, dtype=torch.long)
        idx[c] = 0
        FP = conf_matrix[c, idx].sum()
        FN = conf_matrix[idx, c].sum()
        sensitivity = TP[c] / (TP[c] + FN + self.epsilon)
        precision = TP[c] / (TP[c] + FP + self.epsilon)
        f1_scores[c] += 2.0 * ((precision * sensitivity) / (precision + sensitivity + self.epsilon))
    return f1_scores.mean()

self.classes是标签的数量,self.epsilon是一个非常小的值设置为10-e12这可以防止DivisionByZeroError.

训练时,我计算每批的测量值,并将所有测量值的平均值作为最终分数。

Problem

问题是,当我将自定义 F1 分数与 sklearn 宏 F1 分数进行比较时,它们很少相等。

# example 1
eval_cce 0.5203, eval_f1 0.8068, eval_acc 81.5455, eval_f1_sci 0.8023,
test_cce 0.4784, test_f1 0.7975, test_acc 82.6732, test_f1_sci 0.8097
# example 2
eval_cce 0.3304, eval_f1 0.8211, eval_acc 87.4955, eval_f1_sci 0.8626,
test_cce 0.3734, test_f1 0.8183, test_acc 85.4996, test_f1_sci 0.8424
# example 3
eval_cce 0.4792, eval_f1 0.7982, eval_acc 81.8482, eval_f1_sci 0.8001,
test_cce 0.4722, test_f1 0.7905, test_acc 82.6533, test_f1_sci 0.8139

虽然我尝试扫描互联网,但大多数情况都涉及二进制分类。我还没有找到一个例子来尝试做我想做的事情。

我的问题

我的尝试有什么明显的问题吗?

更新(2020年6月10日)

我还没有弄清楚我的错误。由于时间限制,我决定只使用 sklearn 提供的 F1 宏分数。虽然它不能直接与 GPU 张量一起工作,但无论如何对于我的情况来说它已经足够快了。

然而,如果有人能够解决这个问题,那就太棒了,这样任何其他可能偶然发现这个问题的人都可以解决他们的问题。


我前段时间在 Pytorch 中编写了自己的实现:

from typing import Tuple

import torch


class F1Score:
    """
    Class for f1 calculation in Pytorch.
    """

    def __init__(self, average: str = 'weighted'):
        """
        Init.

        Args:
            average: averaging method
        """
        self.average = average
        if average not in [None, 'micro', 'macro', 'weighted']:
            raise ValueError('Wrong value of average parameter')

    @staticmethod
    def calc_f1_micro(predictions: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Calculate f1 micro.

        Args:
            predictions: tensor with predictions
            labels: tensor with original labels

        Returns:
            f1 score
        """
        true_positive = torch.eq(labels, predictions).sum().float()
        f1_score = torch.div(true_positive, len(labels))
        return f1_score

    @staticmethod
    def calc_f1_count_for_label(predictions: torch.Tensor,
                                labels: torch.Tensor, label_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculate f1 and true count for the label

        Args:
            predictions: tensor with predictions
            labels: tensor with original labels
            label_id: id of current label

        Returns:
            f1 score and true count for label
        """
        # label count
        true_count = torch.eq(labels, label_id).sum()

        # true positives: labels equal to prediction and to label_id
        true_positive = torch.logical_and(torch.eq(labels, predictions),
                                          torch.eq(labels, label_id)).sum().float()
        # precision for label
        precision = torch.div(true_positive, torch.eq(predictions, label_id).sum().float())
        # replace nan values with 0
        precision = torch.where(torch.isnan(precision),
                                torch.zeros_like(precision).type_as(true_positive),
                                precision)

        # recall for label
        recall = torch.div(true_positive, true_count)
        # f1
        f1 = 2 * precision * recall / (precision + recall)
        # replace nan values with 0
        f1 = torch.where(torch.isnan(f1), torch.zeros_like(f1).type_as(true_positive), f1)
        return f1, true_count

    def __call__(self, predictions: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Calculate f1 score based on averaging method defined in init.

        Args:
            predictions: tensor with predictions
            labels: tensor with original labels

        Returns:
            f1 score
        """

        # simpler calculation for micro
        if self.average == 'micro':
            return self.calc_f1_micro(predictions, labels)

        f1_score = 0
        for label_id in range(1, len(labels.unique()) + 1):
            f1, true_count = self.calc_f1_count_for_label(predictions, labels, label_id)

            if self.average == 'weighted':
                f1_score += f1 * true_count
            elif self.average == 'macro':
                f1_score += f1

        if self.average == 'weighted':
            f1_score = torch.div(f1_score, len(labels))
        elif self.average == 'macro':
            f1_score = torch.div(f1_score, len(labels.unique()))

        return f1_score

您可以通过以下方式进行测试:

from sklearn.metrics import f1_score
import numpy as np
errors = 0
for _ in range(10):
    labels = torch.randint(1, 10, (4096, 100)).flatten()
    predictions = torch.randint(1, 10, (4096, 100)).flatten()
    labels1 = labels.numpy()
    predictions1 = predictions.numpy()

    for av in ['micro', 'macro', 'weighted']:
        f1_metric = F1Score(av)
        my_pred = f1_metric(predictions, labels)
        
        f1_pred = f1_score(labels1, predictions1, average=av)
        
        if not np.isclose(my_pred.item(), f1_pred.item()):
            print('!' * 50)
            print(f1_pred, my_pred, av)
            errors += 1

if errors == 0:
    print('No errors!')
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在 PyTorch 中原生测量多类分类的 F1 分数 的相关文章

  • 存储为 np.arrays 的不同数据集的分组堆积条形图

    我正在研究一个平衡问题 我想比较一些数据 我想通过创建不同年份的堆叠条形图来做到这一点 每年 我想要两个不同数据集的堆叠条形图 我正在尝试创建一种 分组堆积条形图 我设法创建了我想要比较的 2 个堆叠条形图 但它们仍然位于两个不同的图中 我
  • 在 Windows 上使用 Python 打开设备句柄

    我正在尝试使用 Giveio sys 驱动程序 该驱动程序需要先打开一个 文件 然后才能访问受保护的内存 我正在查看 WinAVR AVRdude 中的 C 示例 它使用以下语法 define DRIVERNAME giveio HANDL
  • 不要在异常堆栈中显示 Python raise-line

    当我在 Python 库中引发自己的异常时 异常堆栈将引发行本身显示为堆栈的最后一项 这显然不是一个错误 在概念上是正确的 但是当您在外部使用代码 例如作为模块 时 它会将重点放在对调试无用的东西上 有没有办法避免这种情况并强制 Pytho
  • AWS Lambda - 在区域之间自动复制 EC2 快照?

    我想创建一个 Lambda 函数 python 它将自动将已创建的快照复制到另一个区域 我已联系 AWS Support 他们只向我发送了用于 RDS 数据库的 GitHub 脚本 没有 EC2 快照复制脚本 任何帮助都会很棒 谢谢 是的
  • 使用自定义元素类在 Python 中解析 xml

    我想使用 Python 的 xml etree ElementTree 模块解析 xml 文档 但是 我希望生成的树对象中的所有元素都具有我定义的一些类方法 这建议创建我自己的 Python 元素类的子类 但我无法告诉解析器在解析时使用我自
  • 在 AWS Elastic Beanstalk 中部署 Flask 应用程序

    当我部署 Flask 应用程序时 它显示成功 但是当我检索日志时 我看到错误 找不到 Flask 我的需求文件中有烧瓶 任何帮助 Sat Jan 11 06 51 50 503908 2020 error pid 3393 remote 1
  • 使用Python mysql.connector远程连接MySQL

    以下代码 在同一 LAN 内与 mysql 服务器不同的机器上运行 使用 Python3 和 mysql connector 本地连接到 MySQL 数据库 import mysql connector cnx mysql connecto
  • 如何在 PyCharm 中启用 flake8 的自动代码格式化

    我使用 Tox 运行单元测试 并使用 flake8 命令检查代码格式错误 每次我在 PyCharm 中编码时 我都会运行 tox 然后意识到我有一堆烦人的格式错误 我必须返回并手动修复 我希望 PyCharm 自动格式化代码 根据 flak
  • 使用底图和Python在地图中绘制海洋

    我正在绘制此处提供的 netCDF 文件 https goo gl QyUI4J https goo gl QyUI4J Using the code below the map looks like this 然而 我希望海洋是白色的 更
  • OpenCV 在使用 anaconda 的 Linux 上无法与 python 正常工作。收到 cv2.imshow() 未实现的错误

    这就是我得到的确切错误 我的操作系统是 Ubuntu 16 10 OpenCV 错误 未指定错误 该功能未实现 使用 Windows GTK 2 x 或 Carbon 支持重新构建库 如果您使用的是 Ubuntu 或 Debian 请安装
  • 如何在 Python 中重命名文件并保留创建日期

    我知道创建日期不存储在文件系统本身中 但是当我使用时我遇到了问题os rename 它正在更新我正在使用的文件的创建日期 是否可以重命名文件而不更改其原始创建日期 正如都铎所说 你可以使用os stat http docs python o
  • 在 Django 1.9 中使用信号

    在 Django 1 8 中 我能够使用信号执行以下操作 一切顺利 init py from signals import 信号 py receiver pre save sender Comment def process hashtag
  • Python 字符串参数解析

    我正在 python 中使用 cmd 类 它将所有参数作为一个大字符串传递给我 将此 arg 字符串标记为 args 数组的最佳方法是什么 Example args arg arg1 arg2 with quotes arg4 arg5 1
  • SyntaxError:多个异常类型必须用括号括起来

    我是初学者 在使用 python 安装 pycaw 进行音频控制后遇到问题 在放置 pycaw 的基本初始化代码时 出现以下错误 Traceback most recent call last File c Users volumeCont
  • 插入失败“OperationalError:没有这样的列”

    我尝试使用我尝试修复的姓名和电话创建一个数据库 但它会随时向我重播 File exm0 py line 14 in
  • Python - Map/Reduce - 如何在使用 DISCO 计数单词示例中读取 JSON 特定字段

    我正在按照 DISCO 示例来计算文件中的单词数 将单词数作为 Map Reduce 作业 http discoproject org doc disco start tutorial html 我对此工作没有任何问题 但是我想尝试从包含
  • Python 中的“lambda”是什么意思,最简单的使用方法是什么?

    您能否给出一个示例和其他示例来说明何时以及何时不使用 Lambda 我的书给了我一些例子 但它们很令人困惑 拉姆达 起源于拉姆达演算 http en wikipedia org wiki Lambda calculus和 AFAIK 首先实
  • jupyter run magic 将参数传递给笔记本

    当您在第一个 jupyter 笔记本 first ipynb 中时 您可以执行第二个 但如何传递参数呢 假设第二个有以下内容 xx 10 您可以从第一个调用第二个 如下所示 run second ipynb xx will print 10
  • 交响二阶颂歌

    我有一个简单的二阶 ODE 的齐次解 当我尝试使用 Sympy 求解初始值时 它返回相同的解 它应该替代 y 0 和 y 0 并产生一个没有常数的解 但事实并非如此 这是建立方程的代码 它是一个弹簧平衡方程 k 弹簧常数 m 质量 我在其他
  • 在大型文本文件中查找重复记录

    我在一台 Linux 机器 Redhat 上 并且有一个 11GB 的文本文件 文本文件中的每一行包含单个记录的数据 并且该行的前 n 个字符包含该记录的唯一标识符 该文件包含略多于 2700 万条记录 我需要验证文件中不存在具有相同唯一标

随机推荐