使用sklearn宏f1-score作为tensorflow.keras中的指标

2024-04-27

我已经为tensorflow.keras定义了自定义指标,以在每个时期之后计算宏f1分数,如下所示:

from tensorflow import argmax as tf_argmax
from sklearn.metric import f1_score

def macro_f1(y_true, y_pred):
    # labels are one-hot encoded. so, need to convert
    # [1,0,0] to 0 and
    # [0,1,0] to 1 and
    # [0,0,1] to 2. Then pass these arrays to sklearn f1_score.
    y_true = tf_argmax(y_true, axis=1)
    y_pred = tf_argmax(y_pred, axis=1)
    return f1_score(y_true, y_pred, average='macro')

并在模型编译期间使用它

model_4.compile(loss = 'categorical_crossentropy',
                optimizer = Adam(lr=init_lr, decay=init_lr / num_epochs),
                metrics = [Recall(name='recall') #, weighted_f1
                           macro_f1])

当我尝试像这样适应时:

history_model_4 = model_4.fit(train_image_generator.flow(x=train_imgs, y=train_targets, batch_size=batch_size),
                            validation_data = (val_imgs, val_targets),
                            epochs=num_epochs,
                            class_weight=mask_weights_train,
                            callbacks=[model_save_cb, early_stop_cb, epoch_times_cb],
                            verbose=2)

这是错误:

OperatorNotAllowedInGraphError: in user code:

    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:806 train_function  *
        return step_function(self, iterator)
    <ipython-input-57-a890ea61878e>:6 macro_f1  *
        return f1_score(y_true, y_pred, average='macro')
    /usr/local/lib/python3.6/dist-packages/sklearn/metrics/_classification.py:1095 f1_score  *
        return fbeta_score(y_true, y_pred, 1, labels=labels,
    /usr/local/lib/python3.6/dist-packages/sklearn/metrics/_classification.py:1217 fbeta_score  *
        _, _, f, _ = precision_recall_fscore_support(y_true, y_pred,
    /usr/local/lib/python3.6/dist-packages/sklearn/metrics/_classification.py:1478 precision_recall_fscore_support  *
        labels = _check_set_wise_labels(y_true, y_pred, average, labels,
    /usr/local/lib/python3.6/dist-packages/sklearn/metrics/_classification.py:1301 _check_set_wise_labels  *
        y_type, y_true, y_pred = _check_targets(y_true, y_pred)
    /usr/local/lib/python3.6/dist-packages/sklearn/metrics/_classification.py:80 _check_targets  *
        check_consistent_length(y_true, y_pred)
    /usr/local/lib/python3.6/dist-packages/sklearn/utils/validation.py:209 check_consistent_length  *
        uniques = np.unique(lengths)
    <__array_function__ internals>:6 unique  **
        
    /usr/local/lib/python3.6/dist-packages/numpy/lib/arraysetops.py:263 unique
        ret = _unique1d(ar, return_index, return_inverse, return_counts)
    /usr/local/lib/python3.6/dist-packages/numpy/lib/arraysetops.py:311 _unique1d
        ar.sort()
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:877 __bool__
        self._disallow_bool_casting()
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:487 _disallow_bool_casting
        "using a `tf.Tensor` as a Python `bool`")
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:474 _disallow_when_autograph_enabled
        " indicate you are trying to use an unsupported feature.".format(task))

    OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.

是什么导致了此类错误,如何修复它并将其用作每个 y 纪元结束时的评估指标之一?

EDIT 1:
注意:所有这些都是在 jupyter 笔记本中完成的,我在单独的行中添加了“>>>”

# getting a batch to pass to model
>>> a_batch = train_image_generator.flow(x=train_imgs, y=train_targets, batch_size=batch_size).next()
# checking its' type to ensure that it's what i though it is
>>> type(a_batch)
# passing the batch to the model
>>> logits = model_4(a_batch)
# checking the type of output
>>> type(logits)
tensorflow.python.framework.ops.EagerTensor
# extracting only the passed targets to calculate f1-score
>>> _, dummy_targets = a_batch
# checking it's type
>>> type(dummy_targets)
numpy.ndarray
>>> macro_f1(y_true=dummy_targets, y_pred=logits)
0.0811965811965812

sklearn不是 TensorFlow 代码 - 始终建议避免在 TF 中使用在 TF 执行图中执行的任意 Python 代码。

TensorFlow 插件 https://www.tensorflow.org/addons已经实现了 F1 分数 (tfa.metrics.F1Score https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/F1Score),因此更改您的代码以使用它而不是您的自定义指标

确保你pip install tensorflow-addons首先然后

import tensorflow_addons as tfa

model_4.compile(loss = 'categorical_crossentropy',
                optimizer = Adam(lr=init_lr, decay=init_lr / num_epochs),
                metrics = [Recall(name='recall') #, weighted_f1
                           tfa.metrics.F1Score(average='macro')])
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用sklearn宏f1-score作为tensorflow.keras中的指标 的相关文章

随机推荐

  • 如何对表中的每一行运行特定的sql查询?

    所以我的数据库中有两个表 它们看起来都是这样的 通讯 拨打电话 Timestamp FromIDNumber ToIDNumber GeneralLocation 2012 03 02 09 02 30 878 674 Grasslands
  • 成员函数什么时候应该有 const 限定符,什么时候不应该有?

    大约六年前 一位名叫 Harri Porten 的软件工程师写道本文 http www froglogic com porten const html 提出这样的问题 成员函数什么时候应该有 const 限定符 什么时候不应该有 我发现这是
  • 查找最近的城市,例如 oodle.com

    因此 我正在尝试开发一个显示用户列表的应用程序 该网站应该检测用户位置 我为此使用 maxmind api 然后显示用户位置 用户指定半径内的城市的列表 我该怎么做呢 MaxMind API 让我可以通过 IP 地址检测用户的城市 但如何找
  • dprintf 与 break + 命令 + continue 之间有什么区别?

    例如 dprintf main hello n run 生成与以下内容相同的输出 break main commands silent printf hello n continue end run 使用是否有显着的优势dprintf ov
  • C# 中的嵌入字体

    我已经尝试了很多在 c 中的 wpf 应用程序中嵌入字体的方法 该字体的名称是 Roboto 文件名是机器人 ttf如果那有用的话 我已确保它已在程序集中编译 那么如何在a中应用字体TextBlock例如 您可以将字体应用到如下元素中
  • 在插入模式下移至行首

    我知道我可以使用 Home in insert mode Esc i to exit insert mode and enter it again effectively going to the beginning of line But
  • 如何使用 devise_invitable 发送自定义邀请

    我是 ruby 新手 使用 devise invitable gem 进行邀请 每条指令都正确发送 现在我想添加一个自定义主题 该主题将具有受邀者姓名和董事会名称以及与主题相同的自定义内容 我如何在以下操作方法中执行此操作而不使用额外的自定
  • git log --oneline --graph 输出的含义

    我正在学习相对提交引用并尝试理解以下内容git log oneline graph课程中提供的输出 在课程中它说给定的 HEAD 指向9ec05ca提交 HEAD 意思是曾祖父母提交 是0c5975a犯罪 但在我看来4c9749e如果每个
  • 如何将H2数据库文件存储到项目目录中

    当我使用H2数据库时 数据库文件存储在C Users MyName TestDataBase db目录 H2路径是jdbc h2 TestDataBase 这是默认的 H2 数据库路径 是否有可能像这样将 H2 数据库文件存储到我的项目目录
  • 为什么要使用继承? [关闭]

    就目前情况而言 这个问题不太适合我们的问答形式 我们希望答案得到事实 参考资料或专业知识的支持 但这个问题可能会引发辩论 争论 民意调查或扩展讨论 如果您觉得这个问题可以改进并可能重新开放 访问帮助中心 help reopen questi
  • 如何引用所有正在运行的 Excel 应用程序实例(包括隐藏的和没有工作簿的实例)的 COM 对象?

    如何获取每个正在运行的 Excel 应用程序实例的完整引用列表 无论其工作簿数量和可见性状态如何 我知道我可以使用 Windows API 来查找每个 Excel 工作簿窗口 其窗口类名称EXCEL7 让他们的句柄与AccessibleOb
  • 快速崩溃“EXC_BREAKPOINT 0x0000000...”

    我的 iOS 应用程序确实发生了三种不同的崩溃情况 不同的代码位置 但所有三个都带有 exc breakpoint 0x000000 我无法重现它们 它们发生在不同的设备和不同的 iOS 版本上 如前所述 我无法重现它们 我们的测试人员都没
  • Android - API 级别 21 中的日期 [已关闭]

    Closed 这个问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 你好 我对 Android 还很陌生 目前我对本地日期 或我尝试过的任何其他日期格式 有一个大问题 L
  • Bootstrap 3 无法在 Symfony3 中工作

    我刚刚开始学习 Symfony 3 我正在尝试使用 bootstrap 3 为我的表单设置主题 根据文档 http symfony com doc current cookbook form form customization html
  • Python 交互式 Shell 类型应用程序

    我想创建一个交互式 shell 类型的应用程序 例如 gt app py Enter a command to do something eg create name price For to get help enter help wit
  • 如何识别导航堆栈中的先前视图控制器

    我有2个独立的navigationcontrollers 一与RootViewControllerA 和另一个RootViewController B 我有能力推动ViewControllerC 到 A 或 B 的导航堆栈上 问题 当我在V
  • 如何保护swf文件不被反编译?

    我正在使用 Flex 框架从事重要项目 我想对我的算法和代码保密 是否有可能以某种方式保护 swf 文件不被反编译 我不希望有人使用 flash 反编译器提取我的代码 Thanks 这很简单 只需将其保存在您的 PC 上 不要将其放在网络上
  • SQL Server 获取父列表

    我有一个这样的表 id name parent id 1 ab1 3 2 ab2 5 3 ab3 2 4 ab4 null 5 ab5 null 6 ab6 null 我需要使用输入 id 1 进行查询 例如 结果将如下所示 id name
  • 在 Java 1.7.0 下运行的 SQL-Server (MSSQL-JDBC 3.0) 中的日期列检索为过去 2 天

    当使用 SQLServer2008 从 SQLServer2008 检索 DATE 类型的列时 出现奇怪的效果在官方 Oracle JDK 1 7 0 下运行时 主机操作系统是Windows Server 2003 所有日期列均检索为two
  • 使用sklearn宏f1-score作为tensorflow.keras中的指标

    我已经为tensorflow keras定义了自定义指标 以在每个时期之后计算宏f1分数 如下所示 from tensorflow import argmax as tf argmax from sklearn metric import