Keras model.fit log 和 Sklearn.metrics.confusion_matrix 报告的验证准确性指标彼此不匹配

2024-04-07

问题是报道的validation accuracy我从 Keras 获得的价值model.fit历史显着高于validation accuracy我得到的指标sklearn.metrics功能。

我得到的结果model.fit总结如下:

Last Validation Accuracy: 0.81
Best Validation Accuracy: 0.84

结果(标准化)来自sklearn非常不同:

True Negatives: 0.78
True Positives: 0.77

Validation Accuracy = (TP + TN) / (TP + TN + FP + FN) = 0.775 

(see confusion matrix below for reference)

Edit: this calculation is incorrect, because one can not 
use the normalized values to calculate the accuracy, since 
it does not account for differences in the total absolute 
number of points in the dataset. Thanks to the comment by desertnaut
  • Here is the graph of the validation accuracy data from model.fit history: Validation accuracy from model.fit data history

  • 这是 sklearn 生成的混淆矩阵:

我觉得这个问题和this有点相似Sklearn 指标值与 Keras 值有很大不同 https://stackoverflow.com/questions/54580679/sklearn-metrics-values-are-very-different-from-keras-values但我已经检查过这两种方法都在同一数据池上进行验证,因此这个答案可能不适合我的情况。

还有这个问题Keras 二进制精度度量给出的精度太高 https://stackoverflow.com/questions/46354182/keras-binary-accuracy-metric-gives-too-high-accuracy似乎解决了二元交叉熵影响多类问题的方式的一些问题,但在我的情况下它可能不适用,因为它是一个真正的二元分类问题。

以下是使用的命令:

型号定义:

inputs = Input((Tx, ))
n_e = 30
embeddings = Embedding(n_x, n_e, input_length=Tx)(inputs)
out = Bidirectional(LSTM(32, recurrent_dropout=0.5, return_sequences=True))(embeddings)
out = Bidirectional(LSTM(16, recurrent_dropout=0.5, return_sequences=True))(out)
out = Bidirectional(LSTM(16, recurrent_dropout=0.5))(out)
out = Dense(3, activation='softmax')(out)
modelo = Model(inputs=inputs, outputs=out)
modelo.summary()

型号概要:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 100)               0         
_________________________________________________________________
embedding (Embedding)        (None, 100, 30)           86610     
_________________________________________________________________
bidirectional (Bidirectional (None, 100, 64)           16128     
_________________________________________________________________
bidirectional_1 (Bidirection (None, 100, 32)           10368     
_________________________________________________________________
bidirectional_2 (Bidirection (None, 32)                6272      
_________________________________________________________________
dense (Dense)                (None, 3)                 99        
=================================================================
Total params: 119,477
Trainable params: 119,477
Non-trainable params: 0
_________________________________________________________________

模型编译:

mymodel.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])

模型拟合调用:

num_epochs = 30
myhistory = mymodel.fit(X_pad, y, epochs=num_epochs, batch_size=50, validation_data=[X_val_pad, y_val_oh], shuffle=True, callbacks=callbacks_list)

模型拟合日志:

Train on 505 samples, validate on 127 samples

Epoch 1/30
500/505 [============================>.] - ETA: 0s - loss: 0.6135 - acc: 0.6667
[...]
Epoch 10/30
500/505 [============================>.] - ETA: 0s - loss: 0.1403 - acc: 0.9633
Epoch 00010: val_acc improved from 0.77953 to 0.79528, saving model to modelo-10-melhor-modelo.hdf5
505/505 [==============================] - 21s 41ms/sample - loss: 0.1393 - acc: 0.9637 - val_loss: 0.5203 - val_acc: 0.7953
Epoch 11/30
500/505 [============================>.] - ETA: 0s - loss: 0.0865 - acc: 0.9840
Epoch 00011: val_acc did not improve from 0.79528
505/505 [==============================] - 21s 41ms/sample - loss: 0.0860 - acc: 0.9842 - val_loss: 0.5257 - val_acc: 0.7953
Epoch 12/30
500/505 [============================>.] - ETA: 0s - loss: 0.0618 - acc: 0.9900
Epoch 00012: val_acc improved from 0.79528 to 0.81102, saving model to modelo-10-melhor-modelo.hdf5
505/505 [==============================] - 21s 42ms/sample - loss: 0.0615 - acc: 0.9901 - val_loss: 0.5472 - val_acc: 0.8110
Epoch 13/30
500/505 [============================>.] - ETA: 0s - loss: 0.0415 - acc: 0.9940
Epoch 00013: val_acc improved from 0.81102 to 0.82152, saving model to modelo-10-melhor-modelo.hdf5
505/505 [==============================] - 21s 42ms/sample - loss: 0.0413 - acc: 0.9941 - val_loss: 0.5853 - val_acc: 0.8215
Epoch 14/30
500/505 [============================>.] - ETA: 0s - loss: 0.0443 - acc: 0.9933
Epoch 00014: val_acc did not improve from 0.82152
505/505 [==============================] - 21s 42ms/sample - loss: 0.0453 - acc: 0.9921 - val_loss: 0.6043 - val_acc: 0.8136
Epoch 15/30
500/505 [============================>.] - ETA: 0s - loss: 0.0360 - acc: 0.9933
Epoch 00015: val_acc improved from 0.82152 to 0.84777, saving model to modelo-10-melhor-modelo.hdf5
505/505 [==============================] - 21s 42ms/sample - loss: 0.0359 - acc: 0.9934 - val_loss: 0.5663 - val_acc: 0.8478
[...]
Epoch 30/30
500/505 [============================>.] - ETA: 0s - loss: 0.0039 - acc: 1.0000
Epoch 00030: val_acc did not improve from 0.84777
505/505 [==============================] - 20s 41ms/sample - loss: 0.0039 - acc: 1.0000 - val_loss: 0.8340 - val_acc: 0.8110

sklearn 的混淆矩阵:

from sklearn.metrics import confusion_matrix
conf_mat = confusion_matrix(y_values, predicted_values)

预测值和黄金值确定如下:

preds = mymodel.predict(X_val)
preds_ints = [[el] for el in np.argmax(preds, axis=1)]
values_pred = tokenizer_y.sequences_to_texts(preds_ints)
values_gold = tokenizer_y.sequences_to_texts(y_val)

最后,我想补充一点,我已经打印出了数据和所有预测错误,并且我相信 sklearn 值更可靠,因为它们似乎与我打印出保存的“最佳”模型的预测得到的结果相匹配。

另一方面,我无法理解这些指标怎么会如此不同。由于它们都是众所周知的软件,因此我断定我是犯错误的人,但我无法确定错误的位置或方式。


你的问题不恰当;正如已经评论过的,您尚未计算 scikit-learn 模型的实际准确性,因此您似乎将苹果与橙子进行比较。归一化混淆矩阵的计算 (TP + TN)/2 确实not给出准确度。这是使用玩具数据的简单演示,调整plot_confusion_matrix来自docs https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# toy data
y_true = [0, 1, 0, 1, 0, 0, 0, 1]
y_pred =  [1, 1, 1, 0, 1, 1, 0, 1]
class_names=[0,1]

# plot_confusion_matrix function

def plot_confusion_matrix(y_true, y_pred, classes,
                          normalize=False,
                          title=None,
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if not title:
        if normalize:
            title = 'Normalized confusion matrix'
        else:
            title = 'Confusion matrix, without normalization'

    # Compute confusion matrix
    cm = confusion_matrix(y_true, y_pred)

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # We want to show all ticks...
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           # ... and label them with the respective list entries
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    return ax

计算归一化混淆矩阵给出:

plot_confusion_matrix(y_true, y_pred, classes=class_names, normalize=True)
# result:
Normalized confusion matrix
[[ 0.2         0.8       ]
 [ 0.33333333  0.66666667]]

并根据你的不正确根据原理,准确度应该是:

(0.67 + 0.2)/2
# 0.435

(注意在归一化矩阵中rows添加到 100%,这在完整的混淆矩阵中不会发生)

但现在让我们看看是什么real准确度来自于未标准化混淆矩阵:

plot_confusion_matrix(y_true, y_pred, classes=class_names) # normalize=False by default
# result
Confusion matrix, without normalization
[[1 4]
 [1 2]]

由此,根据精度的定义为 (TP + TN) / (TP + TN + FP + FN),我们得到:

(1+2)/(1+2+4+1)
# 0.375

当然,我们不需要混淆矩阵来获得像准确率这样基本的东西;正如评论中已经建议的那样,我们可以简单地使用内置的accuracy_scorescikit-learn的方法:

from sklearn.metrics import accuracy_score
accuracy_score(y_true, y_pred)
# 0.375

毫不奇怪,这与我们从混淆矩阵直接计算的结果一致。


底线:

  • 具体方法(例如accuracy_score)存在,使用它们绝对比临时灵感更好,尤其当某些事情看起来不正确时(例如 Keras 和 scikit-learn 报告的准确性之间存在差异)
  • 事实上,在此示例中,实际准确度低于您自己的方式计算的准确度,这一事实显然并没有说明您报告的具体问题
  • 如果即使计算出数据的正确准确性后与 Keras 的差异仍然存在,请执行以下操作not根据新情况更改问题,因为这会使答案无效,尽管它突出显示了方法中的错误点 - 请改为提出一个新问题
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Keras model.fit log 和 Sklearn.metrics.confusion_matrix 报告的验证准确性指标彼此不匹配 的相关文章

  • 如何有条件地组合两个相同形状的 numpy 数组

    这听起来很简单 但我想我把它想得太复杂了 我想创建一个数组 其元素是从两个形状相同的源数组生成的 具体取决于源数组中哪个元素更大 为了显示 import numpy as np array1 np array 2 3 0 array2 np
  • 蟒蛇 |如何将元素随机添加到列表中

    有没有一种方法可以将元素随机添加到列表中 内置函数 ex def random append lst a lst append b lst append c lst append d lst append e return print ls
  • 返回不包括指定键的字典副本

    我想创建一个函数 返回字典的副本 不包括列表中指定的键 考虑这本词典 my dict keyA 1 keyB 2 keyC 3 致电without keys my dict keyB keyC 应该返回 keyA 1 我想用一行简洁的字典理
  • 如何在 openpyxl 中设置或更改表格的默认高度

    我想通过openpyxl更改表格高度 并且我希望首先默认一个更大的高度值 然后我可以设置自动换行以使我的表格更漂亮 但我不知道如何更改默认高度 唯一的到目前为止 我知道更改表格高度的方法是设置 row dimension idx heigh
  • 在python中调用subprocess.Popen时“系统找不到指定的文件”

    我正在尝试使用svnmerge py合并一些文件 它在底层使用 python 当我使用它时 我收到一个错误 系统找不到指定的文件 工作中的同事正在运行相同版本的svnmerge py 以及 python 2 5 2 特别是 r252 609
  • 一起使用 Argparse 和 Json

    我是 Python 初学者 我想知道 Argparse 和 JSON 是否可以一起使用 说 我有变量p q r 我可以将它们添加到 argparse 中 parser add argument p param1 help x variabl
  • 使用 Python 解析 XML,解析外部 ENTITY 引用

    在我的 S1000D xml 中 它指定了一个带有对公共 URL 的引用的 DOCTYPE 该 URL 包含对包含所有有效字符实体的许多其他文件的引用 我使用 xml etree ElementTree 和 lxml 尝试解析它并得到解析错
  • 如何像在浏览器中一样检索准确的 HTML

    我正在使用 Python 脚本来呈现网页并检索其 HTML 它适用于大多数页面 但对于其中一些页面 检索到的 HTML 不完整 我不太明白为什么 这是我用来废弃此页面的脚本 由于某种原因 每个产品的链接不在 HTML 中 Link http
  • 如何使用注释和聚合在 Django 的 ORM 中执行此 GROUP BY 查询

    我真的不知道如何翻译GROUP BY and HAVING到姜戈的QuerySet annotate and QuerySet aggregate 我正在尝试将这个 SQL 查询转换为 ORM 语言 SELECT EXTRACT year
  • Matplotlib 将颜色图 tab20 更改为三种颜色

    Matplotlib 有一些新的且非常方便的颜色图 选项卡颜色图 https matplotlib org examples color colormaps reference html 我错过的是生成像 tab20b 或 tab20c 这
  • scikit-learn RandomForestClassifier 中的子样本大小

    如何控制用于训练森林中每棵树的子样本的大小 根据 scikit learn 的文档 随机森林是一种适合许多决策的元估计器 数据集的各个子样本上的树分类器并使用 平均以提高预测准确性并控制过度拟合 子样本大小始终与原始输入样本相同 大小 但如
  • 与函数复合 UniqueConstraint

    一个快速的 SQLAlchemy 问题 我有一个 文档 类 其属性为 数字 和 日期 我需要确保没有重复的号码同年 是 有没有办法对 数字 年份 日期 进行UniqueConstraint 我应该使用唯一索引吗 我如何声明功能部分 SQLA
  • Keras 中的损失函数和度量有什么区别? [复制]

    这个问题在这里已经有答案了 我不清楚 Keras 中损失函数和指标之间的区别 该文档对我没有帮助 损失函数用于优化您的模型 这是优化器将最小化的函数 指标用于判断模型的性能 这仅供您查看 与优化过程无关
  • Python:如何从文件中的一行读取字符并将它们转换为浮点数和字符串,具体取决于它们是数字还是字母?

    我有一个如下所示的文件 1 1 C C 1 9873 2 347 3 88776 1 2 C Si 4 887 9 009 1 21 我想逐行读取文件的内容 当我使用的行上只有数字时 for line in readlines file d
  • Scikit Learn - K-Means - 肘部 - 标准

    今天我想学习一些关于 K means 的知识 我已经了解该算法并且知道它是如何工作的 现在我正在寻找正确的 k 我发现肘部准则作为检测正确的 k 的方法 但我不明白如何将它与 scikit learn 一起使用 在 scikit learn
  • dask allocate() 或 apply() 中的变量列名

    我有适用于pandas 但我在将其转换为使用时遇到问题dask 有一个部分解决方案here https stackoverflow com questions 32363114 how do i change rows and column
  • LSTM 批次与时间步

    我按照 TensorFlow RNN 教程创建了 LSTM 模型 然而 在这个过程中 我对 批次 和 时间步长 之间的差异 如果有的话 感到困惑 并且我希望得到帮助来澄清这个问题 教程代码 见下文 本质上是根据指定数量的步骤创建 批次 wi
  • 字母尺度和随机文本上的马尔可夫链

    我想使用 txt 文件中的一本书中的字母频率生成随机文本 以便每个新字符 string lowercase 取决于前一个 如何使用马尔可夫链来做到这一点 或者使用每个字母都有条件频率的 27 个数组更简单 我想使用来自的字母频率生成随机文本
  • 如何使用 FastAPI 在 HTMX 前端中使用 HX-Redirect?

    我试图在登录后在前端重定向 我像这样从我的 htmx 前端发出请求
  • 将数组从 .npy 文件读入 Fortran 90

    我使用 Python 以二维数组 例如 X 的形式生成一些初始数据 然后使用 Fortran 对它们进行一些计算 最初 当数组大小约为 10 000 x 10 000 时 np savetxt 在速度方面表现良好 但是一旦我开始增加数组的维

随机推荐

  • iOS 中的本地通知没有任何声音

    void notifyMe UILocalNotification localNotification UILocalNotification alloc init localNotification fireDate NSDate dat
  • Git:父提交比后代提交年轻?

    我正在浏览http arago project org git projects linux omap3 git http arago project org git projects linux omap3 gitrepo 并遇到了一个奇
  • MASM0015; Web服务HandlerTubeFactory异常

    我正在尝试创建一个可以调用另一个的网络服务链 我已经创建了第一个服务并成功部署在 weblogic12c 上 当第一个 Web 服务尝试调用外部 Web 服务时 出现以下异常 notifyAbout WSTestOuter outer ne
  • 使用 javascript 加载部分 html

    在我的网站上 我加载在服务器 nodejs 上呈现的 html 并将其插入到正确的位置 大多数情况下是带有 id 内容的 div 如何在客户端插入接收到的 html 以便执行包含的脚本标记 我在客户端使用下划线和把手 但 vanillajs
  • 为什么GK110有192个核心和4个扭曲?

    我想感受一下开普勒的架构 但这对我来说没有意义 如果一个 warp 有 32 个线程 其中 4 个被调度 执行 则意味着 128 个核心正在使用 64 个核心处于空闲状态 白皮书中提到了独立指令 那么64核是为这些指令保留的吗 如果是这样
  • 如何从此类图像中删除背景?

    我想删除该图像的背景以仅获取人物 我有数千张这样的图像 基本上是一个人和一个有点发白的背景 我所做的是使用边缘检测器 例如 canny 边缘检测器或索贝尔滤波器 来自skimage图书馆 然后我认为可以做的是 将边缘内的像素变白 并将边缘外
  • 如何在 React + Babel 中允许异步函数?

    我有一个 Typescript React 应用程序 它可以使用 then catch Promise 执行异步函数 但不能使用 async await try catch 执行异步函数 错误是 Uncaught ReferenceErro
  • 使用 Visual Studio 查找 C++ 应用程序中的内存泄漏

    在Linux中 我一直使用valgrind来检查应用程序中是否存在内存泄漏 Windows 中的等效项是什么 这可以用 Visual Studio 2010 来完成吗 Visual Studio 2019 有一个不错的内存分析工具 它可以在
  • jpql“加入获取”与 EntityGraph

    我想使用 jpql 或 jpa 实体图加载相关实体 看起来两者都做同样的事情 我为什么要使用实体图而不是普通的jpql 有什么好处吗 使用jpql有什么区别 select distinct u from User u join fetch
  • 无法从“node_modules\react-native-gesture-handler\createHandler.js”解析“fbjs/lib/areEqual”

    我正在使用 expo 构建一个反应本机应用程序 但是 我有一个错误 因此我无法继续构建该应用程序 我什至在错误消息中提到的node modules 中查找了文件 我正在使用 React native gesture handler 进行屏幕
  • 快速引导大量分层数据的策略以及在任何记录发生更改时更新 Elasticsearch 中的单个分层 json 文档的方法

    根据业务场景 来自 2 个关系表 最好是多个表 例如 6 7 的列必须合并到单个分层 json 文档中 以用于 Elasticsearch 上的单个索引 如下面示例文档中所述 样本文件 员工及联系信息 id 1 name tom john
  • 如何在 TypeORM 查找选项中设置 IS NULL 条件?

    在我的查询中我使用 TypeORMfind选项 我怎样才能拥有IS NULL条件在where clause 如果有人正在寻找 NOT NULL 它会是这样的 import IsNull Not from typeorm return awa
  • AWS CloudWatch 未使用的自定义指标保留和定价 - 2018 年

    如果我理解正确的话 自定义指标似乎将保留 15 个月 因为根据数据 它们会聚合为更高分辨率https aws amazon com cloudwatch faqs https aws amazon com cloudwatch faqs 这
  • 正则表达式匹配未完成

    我曾经有过一次回答了一个问题 https stackoverflow com a 17723854 882200关于将带引号的字符串与转义引号匹配 似乎有些情况会在 NET 上挂起并在 Mono 上崩溃 带有OutOfMemoryExcep
  • 使用描述符进行类型提示

    In 这个拉取请求 https github com python mypy pull 2266看起来添加了对描述符的类型提示支持 然而 似乎没有发布最终的 正确 用法示例 也没有添加任何文档到typing module https doc
  • C# htmlagilitypack,捕获重定向

    大家好 这真的很简单 我希望 我正在使用 htmlagility pack 进行网络爬虫 那么 如果我输入 url 然后将我定向到新的 url 会发生什么情况 如何捕获该新的重定向 URL 如果 htmlagilitypack 没有办法 有
  • 登录 GCP 和本地

    我正在构建一个旨在在 Google Cloud Platform 中的虚拟机上运行的系统 但是 作为一种备份形式 它也可以在本地运行 话虽这么说 我目前的问题是日志记录 我有两个记录器 都可以工作 一个本地记录器和一个云记录器 云记录器 i
  • 在 FTP 上上传文件

    我想将文件从一台服务器上传到另一台 FTP 服务器 以下是我上传文件的代码 但它抛出错误 远程服务器返回错误 550 文件不可用 例如 未找到文件 无法访问 这是我的代码 string CompleteDPath ftp URL strin
  • 使用贝叶斯优化的深度学习结构的超参数优化

    我为原始信号分类任务构建了 CLDNN 卷积 LSTM 深度神经网络 结构 每个训练周期运行约 90 秒 超参数似乎很难优化 我一直在研究优化超参数的各种方法 例如随机或网格搜索 并发现了贝叶斯优化 虽然我还没有完全理解优化算法 但我认为它
  • Keras model.fit log 和 Sklearn.metrics.confusion_matrix 报告的验证准确性指标彼此不匹配

    问题是报道的validation accuracy我从 Keras 获得的价值model fit历史显着高于validation accuracy我得到的指标sklearn metrics功能 我得到的结果model fit总结如下 Las