计算混淆矩阵的更快方法?

2023-12-12

我正在计算图像语义分割的混淆矩阵,如下所示,这是一种非常冗长的方法:

def confusion_matrix(preds, labels, conf_m, sample_size):
    preds = normalize(preds,0.9) # returns [0,1] tensor
    preds = preds.flatten()
    labels = labels.flatten()
    for i in range(len(preds)):
        if preds[i]==1 and labels[i]==1:
            conf_m[0,0] += 1/(len(preds)*sample_size) # TP
        elif preds[i]==1 and labels[i]==0:
            conf_m[0,1] += 1/(len(preds)*sample_size) # FP
        elif preds[i]==0 and labels[i]==0:
            conf_m[1,0] += 1/(len(preds)*sample_size) # TN
        elif preds[i]==0 and labels[i]==1:
            conf_m[1,1] += 1/(len(preds)*sample_size) # FN 
    return conf_m

在预测循环中:

conf_m = torch.zeros(2,2) # two classes (object or no-object)
for img,label in enumerate(data):
    ...
    out = Net(img)
    conf_m = confusion_matrix(out, label, len(data))
    ...

是否有更快的方法(在 PyTorch 中)来有效计算图像语义分割输入样本的混淆矩阵?


我使用这两个函数来计算混淆矩阵(如其定义)sklearn):

# rewrite sklearn method to torch
def confusion_matrix_1(y_true, y_pred):
    N = max(max(y_true), max(y_pred)) + 1
    y_true = torch.tensor(y_true, dtype=torch.long)
    y_pred = torch.tensor(y_pred, dtype=torch.long)
    return torch.sparse.LongTensor(
        torch.stack([y_true, y_pred]), 
        torch.ones_like(y_true, dtype=torch.long),
        torch.Size([N, N])).to_dense()

# weird trick with bincount
def confusion_matrix_2(y_true, y_pred):
    N = max(max(y_true), max(y_pred)) + 1
    y_true = torch.tensor(y_true, dtype=torch.long)
    y_pred = torch.tensor(y_pred, dtype=torch.long)
    y = N * y_true + y_pred
    y = torch.bincount(y)
    if len(y) < N * N:
        y = torch.cat(y, torch.zeros(N * N - len(y), dtype=torch.long))
    y = y.reshape(N, N)
    return y

y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]

confusion_matrix_1(y_true, y_pred)
# tensor([[2, 0, 0],
#         [0, 0, 1],
#         [1, 0, 2]])

在类数量较少的情况下,第二个函数速度更快。

%%timeit
confusion_matrix_1(y_true, y_pred)
# 102 µs ± 30.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
confusion_matrix_2(y_true, y_pred)
# 25 µs ± 149 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

计算混淆矩阵的更快方法? 的相关文章

  • 如何在 Windows 上使用 Python 3.6 来安装 Python 2.7

    我想问一下如何使用pip install对于 Python 2 7 当我之前安装并使用 Python 3 6 时 我现在必须使用 Windows 上的 Python 版本 pip install 继续安装 Python 3 6 我需要使用以
  • Yocto 如何停止 cmake 在本机 sysroot 路径中查找链接?

    到目前为止 我正在尝试将 dlib python 模块添加到我的图像中 这是我正在研究的食谱 python3 dlib 19 21 1 bb SUMMARY A toolkit for making real world machine l
  • 我正在尝试为 Antlr4 Python3.g4 语法文件生成解析树,以解析 python3 代码

    我正在使用 ANTLR4 并尝试为我拥有的 python 文件生成解析树 我使用了 ANTLR4 文档中的语法文件 python3 g4 我安装了antlr4 python3 runtime 并且运行了以下命令 antlr4 Dlangua
  • 如何以干净高效的方式在 pytorch 中获得小批量?

    我试图做一件简单的事情 即使用火炬通过随机梯度下降 SGD 训练线性模型 import numpy as np import torch from torch autograd import Variable import pdb def
  • Instagram 图表 api 日期之间的媒体帖子

    我正在尝试使用以下方法从我管理的 Instagram Business 个人资料中检索上个月的媒体帖子 since and until 但它似乎无法正常工作 因为 API 返回的帖子超出了我选择的时间范围 我使用以下字符串来调用 API b
  • 地图与星图的性能?

    我试图对两个序列进行纯Python 没有外部依赖 逐元素比较 我的第一个解决方案是 list map operator eq seq1 seq2 然后我发现starmap函数来自itertools 这看起来和我很相似 但事实证明 在最坏的情
  • 避免在列表理解中计算相同的表达式两次[重复]

    这个问题在这里已经有答案了 我在列表理解中使用一个函数和一个 if 函数 new list f x for x in old list if f x 0 令我恼火的是这个表达f x 在每个循环中计算两次 有没有办法以更清洁的方式做到这一点
  • 连接运算符 + 或 ,

    var1 abc var2 xyz print literal var1 var2 literalabcxyz print literal var1 var2 literal abc xyz 除了带有 的自动空格之外 两者有什么区别 哪个通
  • urllib.error.URLError:

    Python 3 4 2 当我在脚本中运行 urllib request urlopen url 时 出现了一个奇怪的错误 如果我直接在 Python 解释器中运行它 它可以正常工作 但当我通过 bash shell Linux 在脚本内运
  • 拓扑错误:无法执行操作“GEOSIntersection_r”

    嗨 大家好 我正在尝试将选区形状文件映射到议会选区 我有 两者 的形状文件 基本上 我必须将人口普查数据中地区级别给出的所有变量映射到议会选区级别 所以我正在关注 pycontalk https github com gramener py
  • 在 while 循环中更改 tkinter 画布中的图像

    我的完整代码是here https gist github com ItsBerry de245ba70376cb07f4dbe2d25c223f5f 我正在尝试使用 tkinter 的画布创建一个小游戏 让人们练习学习高音谱号上的音符 最
  • Matplotlib 在 Ubuntu 18.04 上引发 MemoryError,但在 Windows 10 上则不会

    我正在 Ubuntu 机器上为 Windows 用户开发软件 它能做什么 对数千张图像进行物体检测 并将结果与 一些测量数据进行比较 示波器数据 200MB 5000 万个数据值 最后绘制并保存结果 在此步骤之后 程序将前进到下一个数据集
  • Whatsapp 自动机器人无法在 WhatsApp 联系人列表中搜索

    我正在尝试实现一个 WhatsApp 机器人 它使用chromedriver并打开 Whatsapp 网页 并向联系人发送消息 这些是该程序的步骤 从 Excel 文件中读取联系人信息 设置您想要发送消息的时间以及要发送的消息 搜索该名称并
  • 在 Python 中延迟转置列表

    所以 我有一个延迟生成的可迭代的三元组 我试图弄清楚如何将其转换为 3 个可迭代对象 分别由元组的第一个 第二个和第三个元素组成 然而 我希望这件事能懒惰地完成 所以 举例来说 我希望 1 2 3 4 5 6 7 8 9 将变成 1 4 7
  • 为什么当循环数变大时,设置的打印值会被排序?

    它是python 3 8 当输入10时 打印是随机的 但是当输入900时 打印的顺序与 print sorted s 相同 import random s set for i in range int input loop nums n v
  • 如何从 Python 安全地清除 Gnome 中的两个剪贴板?

    Gnome 桌面有 2 个剪贴板 X org 保存每个选择 和旧版剪贴板 CTRL C 我正在编写一个简单的 python 脚本来清除两个剪贴板 最好是安全地清除 因为它可以在复制粘贴密码后完成 我在这里看到的代码是这样的 empty X
  • 使用 Selenium 从 twitter 抓取动态推文

    这可能看起来像一个重复的问题 但相信我 我在 Twitter 上观察到了一些新东西 我之前制作了一个 Twitter 抓取工具 它使用滚动和等待动态元素来获取给定数量的推文 但现在好像不行了 它不会抓取超过 10 条推文 此外 它抓取的推文
  • lxml 的类型提示?

    Python 新手 具有静态类型语言背景 我想要类型提示https lxml de https lxml de只是为了便于开发 mypy 标记问题并建议方法会很好 据我所知 这是一个 python 2 0 模块 没有类型 目前我用过http
  • 在 groupby 聚合函数中传递参数

    我有我引用的数据框df在代码中 我在每组的多个列上应用聚合函数 我还应用了用户定义的 lambda 函数f4 f5 f6 f7 有些功能非常相似 例如f4 f6 and f7其中只有参数值不同 我可以从以下位置传递这些参数吗字典 d 这样我
  • 通过子类化 `io.TextIOWrapper` 来子类化文件 - 但它的构造函数有什么签名?

    我正在尝试子类化io TextIOWrapper下列的这个帖子 https stackoverflow com a 23796737 974555 虽然我的目标不同 以此开始 注意 动机 https stackoverflow com a

随机推荐

  • 堆栈与堆属性的 QT 特定差异?

    通常 在编写 C 代码时 我会始终将对象保留为普通属性 从而利用 RAII 然而 在 QT 中 删除对象的责任可以由析构函数承担QObject 因此 假设我们定义了一些特定的小部件 那么我们有两种可能性 1 使用QT的系统 class Wi
  • 不允许从一个 Google 电子表格访问另一个 Google 电子表格

    我试图通过其他电子表格中的 onEdit 事件为我的 Google 电子表格设置新值 我收到异常 不允许执行操作 我不明白我到底做错了什么 我会很高兴得到你的帮助 因为我只是在 JS Google Docs 脚本中做第一步 function
  • 将函数与 numpy 数组的每个元素积分作为积分极限

    我在 python 中有一个函数 也使用 scipy 和 numpy 定义为 import numpy as np from scipy import integrate LCDMf lambda x 1 0 np sqrt 0 3 1 x
  • 尝试使用 groupby 查找每月 5 个最大值

    我试图显示前三个值nc type每个月 我尝试使用n largest但这并没有按日期完成 原始数据 area nc type occurred date 0 Filling x 12 23 2015 0 00 1 Filling f 12
  • ddply+summary函数列名输入

    我正在尝试使用ddply and summarise一起从plyr包 但在解析不断变化的列名时遇到困难 在我的示例中 我想要一些能够以编程方式在 X1 中解析的东西 而不是在 X1 中硬编码到 ddply 函数中 举例说明 require
  • Android - 无法隐藏进度条

    因此 我检查了其他问题以隐藏进度条 但所有问题似乎都建议做我已经在做的事情 我正在尝试使用 mProductListProgressBar setVisibility View GONE 我找到它 mProductListProgressB
  • Makefile 更新了库依赖项

    我有一个很大的 makefile 它构建几个库 安装它们 然后继续构建链接到这些已安装库的对象 我的麻烦是我想使用 lfoo lbar 作为 g 标志来链接两个已安装的库 但依赖关系变得混乱 如果我更改库 foo 所依赖的标头 42 h 那
  • 查找 python 的安装位置(如果不是默认目录)

    Python 在我的机器上 我只是不知道在哪里 如果我在终端中输入 python 它将打开 Python 2 6 4 这不在它的默认目录中 肯定有一种方法可以从这里找到它的安装位置 sys有一些有用的东西 python Python 2 6
  • 如何在 3D 空间中围绕 x 轴旋转正方形

    所以我一直在尝试学习 3D 渲染是如何工作的 我尝试编写一个脚本 目标是在 3D 空间中旋转平面 2D 正方形 我首先在标准化空间 1 1 中定义一个正方形 请注意 只有 x 和 y 被标准化 class Vec3 3D VECTOR de
  • 为什么在尝试从列表中删除元素时会收到 UnsupportedOperationException?

    我有这个代码 public static String SelectRandomFromTemplate String template int count String split template split List
  • Python Selenium onclick 抛出 ElementNotInteractableException

    在我想使用 Selenium 进行交互的网站上 有以下 html 代码部分 a href img src img rename png 1 alt change name title change name a 这显示了一个小图像 单击该图
  • 根据单元格值更改 Excel 中的弧长

    我想根据单元格值动态更改 Excel 中的弧长 例如 如果单元格值 100 则拱形应成为完整的圆形 如果该值 0 它应该消失 我发现下面的代码可以更改形状的大小 但我不知道如何修改它来更改长度 Example 非常感谢您的帮助 Privat
  • 通过变量之一的值设置堆积条形图的顺序

    我被要求制作一个堆叠条形图 其中的条形和值以精确的方式堆叠和排序 在本例中 A3 在左侧 A2 在中间 A1 在右侧 我已经解决了 我没有注意到的是 我还被要求按 A1 的值降序排列条形 在这种情况下 这意味着 值 11 出现在顶部 按降序
  • Runtime.exec() 的安全问题

    我正在使用 Runtime exec 来运行可执行文件 我一直在研究并发现在应用程序中使用它时可能存在安全问题 使用 Runtime exec 运行可执行文件时是否存在安全问题 Jeanne Boyarsky 显然你不能按照你提到的方式注入
  • 本地主机上的目录名(__FILE__)

    我正在使用 WAMP 并且在 www 目录中有一个开发站点 我想用dirname FILE 定义服务器根目录的路径 目前我正在使用一个配置文件 其中包含 define PATH dirname FILE 我将配置文件包含在我的header
  • Android Studio:错误文件名、目录名或卷标语法不正确

    我使用的是 Windows 7 64 位并切换到最新的 Android Studio 但收到此错误 错误 配置项目 myproject 时出现问题 无法标准化文件 C Users me Apps Android android myproj
  • 将 8 个布尔值转换为 1 个字节的最佳方法?

    我想将 8 个布尔值保存到一个字节 然后将其保存到一个文件中 这项工作必须针对非常大的数据完成 我使用了以下代码 但我不确定它是最好的代码 就术语而言 速度和空间 int bits 1 0 0 0 0 1 1 1 char a 0 for
  • UnicodeDecodeError:“ascii”编解码器无法解码位置 47 中的字节 0x92:序号不在范围内(128)

    我正在尝试使用 Python 在 StringIO 对象中写入数据 然后最终使用 psycopg2 的 copy from 函数将此数据加载到 postgres 数据库中 首先 当我这样做时 copy from 抛出错误 错误 编码 UTF
  • d3.js 中的转义字符

    我需要在图表的刻度格式中显示微摩尔每升 mol L 但是当我传入 mol L 时 它会显示字符 而不是 mu 的符号 我如何让它渲染符号 在这种情况下 您不应使用 HTML 实体 一旦你处理 SVG 请使用 u00B5 检查这个片段 var
  • 计算混淆矩阵的更快方法?

    我正在计算图像语义分割的混淆矩阵 如下所示 这是一种非常冗长的方法 def confusion matrix preds labels conf m sample size preds normalize preds 0 9 returns