如何编写混淆矩阵

2024-02-08

我用Python写了一个混淆矩阵计算代码:

def conf_mat(prob_arr, input_arr):
    # confusion matrix
    conf_arr = [[0, 0], [0, 0]]

    for i in range(len(prob_arr)):
        if int(input_arr[i]) == 1:
            if float(prob_arr[i]) < 0.5:
                conf_arr[0][1] = conf_arr[0][1] + 1
            else:
                conf_arr[0][0] = conf_arr[0][0] + 1
        elif int(input_arr[i]) == 2:
            if float(prob_arr[i]) >= 0.5:
                conf_arr[1][0] = conf_arr[1][0] +1
            else:
                conf_arr[1][1] = conf_arr[1][1] +1

    accuracy = float(conf_arr[0][0] + conf_arr[1][1])/(len(input_arr))

prob_arr是我的分类代码返回的数组,示例数组如下所示:

 [1.0, 1.0, 1.0, 0.41592955657342651, 1.0, 0.0053405015805891975, 4.5321494433440449e-299, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.70943426182688163, 1.0, 1.0, 1.0, 1.0]

input_arr是数据集的原始类标签,如下所示:

[2, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 2, 1, 1, 1]

我的代码试图做的是:我明白了prob_arr and input_arr对于每个类别(1 和 2),我检查它们是否分类错误。

但我的代码只适用于两个类。如果我为多分类数据运行此代码,它不起作用。我怎样才能为多个班级做这个?

例如,对于具有三个类的数据集,它应该返回:[[21, 7, 3], [3, 38, 6],[5, 4, 19]].


Scikit-Learn 提供了confusion_matrix功能

from sklearn.metrics import confusion_matrix

y_actu = [2, 0, 2, 2, 0, 1, 1, 2, 2, 0, 1, 2]
y_pred = [0, 0, 2, 1, 0, 2, 1, 0, 2, 0, 2, 2]
confusion_matrix(y_actu, y_pred)

输出一个 Numpy 数组

array([[3, 0, 0],
       [0, 1, 2],
       [2, 1, 3]])

但您也可以使用 Pandas 创建混淆矩阵:

import pandas as pd

y_actu = pd.Series([2, 0, 2, 2, 0, 1, 1, 2, 2, 0, 1, 2], name='Actual')
y_pred = pd.Series([0, 0, 2, 1, 0, 2, 1, 0, 2, 0, 2, 2], name='Predicted')
df_confusion = pd.crosstab(y_actu, y_pred)

您将获得一个(带有精美标签的)Pandas DataFrame:

Predicted  0  1  2
Actual
0          3  0  0
1          0  1  2
2          2  1  3

如果你添加margins=True like

df_confusion = pd.crosstab(y_actu, y_pred, rownames=['Actual'], colnames=['Predicted'], margins=True)

您还将获得每行和每列的总和:

Predicted  0  1  2  All
Actual
0          3  0  0    3
1          0  1  2    3
2          2  1  3    6
All        5  2  5   12

您还可以使用以下方法获得归一化混淆矩阵:

df_confusion = pd.crosstab(y_actu, y_pred)
df_conf_norm = df_confusion.div(df_confusion.sum(axis=1), axis="index")

Predicted         0         1         2
Actual
0          1.000000  0.000000  0.000000
1          0.000000  0.333333  0.666667
2          0.333333  0.166667  0.500000

您可以使用以下命令绘制这个混淆矩阵

import matplotlib.pyplot as plt


def plot_confusion_matrix(df_confusion, title='Confusion matrix', cmap=plt.cm.gray_r):
    plt.matshow(df_confusion, cmap=cmap) # imshow
    #plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(df_confusion.columns))
    plt.xticks(tick_marks, df_confusion.columns, rotation=45)
    plt.yticks(tick_marks, df_confusion.index)
    #plt.tight_layout()
    plt.ylabel(df_confusion.index.name)
    plt.xlabel(df_confusion.columns.name)


df_confusion = pd.crosstab(y_actu, y_pred)
plot_confusion_matrix(df_confusion)

或使用以下方法绘制归一化混淆矩阵:

plot_confusion_matrix(df_conf_norm)  

您可能也对此项目感兴趣https://github.com/pandas-ml/pandas-ml https://github.com/pandas-ml/pandas-ml及其 Pip 包https://pypi.python.org/pypi/pandas_ml https://pypi.python.org/pypi/pandas_ml

有了这个包,混淆矩阵就可以被漂亮地打印、绘图。 您可以对混淆矩阵进行二值化,获取类别统计信息,例如 TP、TN、FP、FN、ACC、TPR、FPR、FNR、TNR (SPC)、LR+、LR-、DOR、PPV、FDR、FOR、NPV 和一些总体统计数据统计数据

In [1]: from pandas_ml import ConfusionMatrix
In [2]: y_actu = [2, 0, 2, 2, 0, 1, 1, 2, 2, 0, 1, 2]
In [3]: y_pred = [0, 0, 2, 1, 0, 2, 1, 0, 2, 0, 2, 2]
In [4]: cm = ConfusionMatrix(y_actu, y_pred)
In [5]: cm.print_stats()
Confusion Matrix:

Predicted  0  1  2  __all__
Actual
0          3  0  0        3
1          0  1  2        3
2          2  1  3        6
__all__    5  2  5       12


Overall Statistics:

Accuracy: 0.583333333333
95% CI: (0.27666968568210581, 0.84834777019156982)
No Information Rate: ToDo
P-Value [Acc > NIR]: 0.189264302376
Kappa: 0.354838709677
Mcnemar's Test P-Value: ToDo


Class Statistics:

Classes                                        0          1          2
Population                                    12         12         12
P: Condition positive                          3          3          6
N: Condition negative                          9          9          6
Test outcome positive                          5          2          5
Test outcome negative                          7         10          7
TP: True Positive                              3          1          3
TN: True Negative                              7          8          4
FP: False Positive                             2          1          2
FN: False Negative                             0          2          3
TPR: (Sensitivity, hit rate, recall)           1  0.3333333        0.5
TNR=SPC: (Specificity)                 0.7777778  0.8888889  0.6666667
PPV: Pos Pred Value (Precision)              0.6        0.5        0.6
NPV: Neg Pred Value                            1        0.8  0.5714286
FPR: False-out                         0.2222222  0.1111111  0.3333333
FDR: False Discovery Rate                    0.4        0.5        0.4
FNR: Miss Rate                                 0  0.6666667        0.5
ACC: Accuracy                          0.8333333       0.75  0.5833333
F1 score                                    0.75        0.4  0.5454545
MCC: Matthews correlation coefficient  0.6831301  0.2581989  0.1690309
Informedness                           0.7777778  0.2222222  0.1666667
Markedness                                   0.6        0.3  0.1714286
Prevalence                                  0.25       0.25        0.5
LR+: Positive likelihood ratio               4.5          3        1.5
LR-: Negative likelihood ratio                 0       0.75       0.75
DOR: Diagnostic odds ratio                   inf          4          2
FOR: False omission rate                       0        0.2  0.4285714

我注意到一个关于混淆矩阵的新 Python 库名为PyCM http://www.pycm.ir/已经出来了:也许你可以看看。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何编写混淆矩阵 的相关文章

随机推荐

  • laravel dusk TeaDown() 必须与 Illuminate\Foundation\Testing\TestCase::tearDown() 兼容

    public function tearDown this gt browse function Browser browser browser gt click navbarDropdown gt click dropdown item
  • 这些嵌套向量是如何连接的?

    我编写了一段代码 它创建了一个向量 记分板 其中包含 3 个大小为 3 的独立向量 所有向量都包含符号 在所有索引 0 2 处 当我现在执行 向量集 时在记分牌的第一个向量上 要将其第一个元素更改为 X 向量 2 和 3 也会更改 这是如何
  • 防止在 Javascript 中自动创建全局变量

    我刚刚花了一些时间调试一个问题 归根结底是忘记使用var关键字位于新变量标识符前面 因此 Javascript 会自动在全局范围内创建该变量 有什么方法可以防止这种情况发生 或者更改默认行为 而不使用像 JSLint 这样的验证器 在编写和
  • 如何仅在第一次启动时显示视图?

    我使用 Xcode 4 5 和故事板构建了一个应用程序 第一次启动应用程序时 我希望初始视图控制器出现 并附带必须接受才能继续的条款和条件 之后 我希望应用程序启动并跳过第一个视图控制器并转到第二个视图控制器 我知道我必须使用 NSUser
  • Android 4.3 BTLE作为服务器:如何启动广告?

    我正在尝试使用 4 3 中的新 BTLE API 在 Nexus 7 上实现 BTLE 服务器 我遇到了几个问题 首先 SDK 中没有示例 唯一的例子是针对客户的 其次 文档实际上告诉你做错误的事情 它指出 人们必须使用BluetoothA
  • 如何检测 MemoryMappedFile 是否正在使用

    在 C 4 0 中 MemoryMappedFile有几种工厂方法 CreateFromFile CreateNew CreateOrOpen or OpenExisting 我需要打开MemoryMappedFile如果存在 则从文件创建
  • Gitlab docker 和 external_url

    你好 我使用 docker 安装了最新的 gitlab 我使用 p 10080 80 和 10022 22 启动容器 我可以浏览 gitlab 并执行我需要的操作 我什至可以分别使用端口 10080 和 10022 git 克隆 http
  • 如何在android webview中启用默认突出显示菜单?

    如何在 android webview 中启用默认文本突出显示菜单 例如 复制 粘贴 搜索 共享 在 Android 1 5 2 3 上工作 您可以使用emulateShiftHeld 自 2 2 起公开 但现在已弃用 此方法将您的 Web
  • 使用 'hd' 参数限制 Google OAuth 访问一个域 (Django / python-social-auth)

    我正在构建一个内部网络应用程序供我的公司使用 并希望使用我们的 Google Apps 域来管理来自我们公司域用户名的访问 本问题的其余部分为 example com 我在用着 Django 1 9 5 python social auth
  • 如何在日期字段上显示日期选择器日历

    这是关于如何使用 jQuerydate picker在 django 支持的站点中 models py is from django db import models class holidaytime models Model holid
  • 对数组使用限制?

    有没有办法告诉 C99 编译器我访问给定数组的唯一方法是使用 myarray index 说这样的话 int heavy calcualtions float restrict range1 float restrict range2 fl
  • 为 iPhone 本地化货币

    我希望我的 iPhone 应用程序允许用户使用适当的符号 等 输入 显示和存储货币金额 NSNumberFormatter 会做我需要的一切吗 当用户切换其区域设置并且这些金额 美元 日元等 存储为 NSDecimalNumbers 时会发
  • Java 中 HTML 字符编码的转换

    我们正在尝试下载网页源代码 但是由于字符编码的原因 我们无法正确看到某些特定字符 例如 我们尝试了以下代码来转换字符串 text 变量 的编码 byte xyz text getBytes text new String xyz windo
  • React:搜索和过滤功能存在问题

    我正在开发一个组件 它应该能够 按输入搜索 使用输入字段 在触发 onBlur 事件后将调用一个函数 之后onBlur事件开始寻找 方法将运行 按所选流派过滤 用户可以从其他组件中从流派列表中选择流派 之后onClick事件启动过滤器 方法
  • 使用 Facebook 图表来获取粉丝页面的粉丝?

    我有一个粉丝页面 位于http www facebook com shop4tronix http www facebook com shop4tronix 我可以通过以下方式访问此页面上的信息 http graph facebook co
  • 文本区域 onresize 不起作用

    根据w3schools
  • 返回主菜单不断循环菜单

    当程序第一次启动时 我可以成功地从主菜单中选择任何选项 但是 当我从任何子菜单中选择 返回主菜单 选项时 它都会返回主菜单 但无论我之后再次按哪个选项 它都会继续循环该菜单 只允许我选择返回主菜单选项 如何将选择重置到不会继续循环的位置 我
  • GDB源路径

    如何让gdb使用不同的目录来查找源文件 例如 我在编译期间的源文件位于目录中 home foo bar c 接下来 我将其移动到目录中 tmp debug home foo bar c 如何强制gdb在该目录中搜索 根据这个网站 https
  • 基于有序对多关系对描述符进行排序

    我的核心数据模型的描述 项目和问题实体 项目有一个有序的一对多关系至已命名问题的问题 问题与名为parentProject 的项目具有一对一的关系 这是我获取问题的代码 let fetchRequest NSFetchRequest ent
  • 如何编写混淆矩阵

    我用Python写了一个混淆矩阵计算代码 def conf mat prob arr input arr confusion matrix conf arr 0 0 0 0 for i in range len prob arr if in