基于tensorflow的流指标的自定义指标返回NaN

2024-01-09

我正在尝试将 F1 分数定义为 TensorFlow 中的自定义指标DNNClassifier。为此,我编写了一个函数

def metric_fn(predictions=[], labels=[], weights=[]):
    P, _ = tf.contrib.metrics.streaming_precision(predictions, labels)
    R, _ = tf.contrib.metrics.streaming_recall(predictions, labels)
    if P + R == 0:
        return 0
    return 2*(P*R)/(P+R)

使用streaming_precision and streaming_recall从 TensorFlow 计算 F1 分数。之后我在验证指标中添加了一个新条目:

validation_metrics = {
    "accuracy":
        tf.contrib.learn.MetricSpec(
            metric_fn=tf.contrib.metrics.streaming_accuracy,
            prediction_key=tf.contrib.learn.PredictionKey.CLASSES),
    "precision":
        tf.contrib.learn.MetricSpec(
            metric_fn=tf.contrib.metrics.streaming_precision,
            prediction_key=tf.contrib.learn.PredictionKey.CLASSES),
    "recall":
        tf.contrib.learn.MetricSpec(
            metric_fn=tf.contrib.metrics.streaming_recall,
            prediction_key=tf.contrib.learn.PredictionKey.CLASSES),
    "f1score":
        tf.contrib.learn.MetricSpec(
            metric_fn=metric_fn,
            prediction_key=tf.contrib.learn.PredictionKey.CLASSES)
}

然而,尽管我得到了正确的精度和召回值,f1score总是nan:

INFO:tensorflow:Saving dict for global step 151: accuracy = 0.982456, accuracy/baseline_label_mean = 0.397661, accuracy/threshold_0.500000_mean = 0.982456, auc = 0.982867, f1score = nan, global_step = 151, labels/actual_label_mean = 0.397661, labels/prediction_mean = 0.406118, loss = 0.310612, precision = 0.971014, precision/positive_threshold_0.500000_mean = 0.971014, recall = 0.985294, recall/positive_threshold_0.500000_mean = 0.985294

我的有问题metric_fn,但我无法弄清楚。 价值P and R获得于metric_fn具有以下形式Tensor("precision/value:0", shape=(), dtype=float32)。我觉得这有点奇怪。我期待一个标量张量。

任何帮助表示赞赏。


我认为问题可能来自于您在您的应用程序中使用的流媒体指标这一事实metric_fn没有得到任何更新。

尝试以下方法(我还根据自己的口味进行了一些细微的修改):

def metric_fn(predictions=None, labels=None, weights=None):
    P, update_op1 = tf.contrib.metrics.streaming_precision(predictions, labels)
    R, update_op2 = tf.contrib.metrics.streaming_recall(predictions, labels)
    eps = 1e-5;
    return (2*(P*R)/(P+R+eps), tf.group(update_op1, update_op2))
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

基于tensorflow的流指标的自定义指标返回NaN 的相关文章

  • tensorflow Protobuf编译问题

    我想为 google 对象检测 API 编译 protobuf 库 我按照官方教程输入protoc object detection protos proto python out 然后我得到的是 object detection prot
  • CVXPY 二次规划; ArpackNoConvergence 错误

    我尝试使用 Python 包 CVXPY 来解决第一种形式的凸二次规划问题 https www cvxpy org examples basic quadratic program html https www cvxpy org exam
  • 如何将本机 popcount 与 numba 一起使用

    我正在使用 numba 0 57 1 我想在我的代码中利用本机 CPU popcount 我现有的代码太慢 因为我需要运行它数亿次 这是一个 MWE import numba as nb nb njit nb uint64 nb uint6
  • 重新索引错误没有意义

    I have DataFrames大小在 100k 到 2m 之间 我正在处理这个问题的框架是如此之大 但请注意 我必须对其他框架执行相同的操作 gt gt gt len data 357451 现在这个文件是通过编译许多文件创建的 所以它
  • scipy.misc.imshow RuntimeError('无法执行图像视图')

    我正在测试scipy misc imshow https docs scipy org doc scipy 0 15 1 reference generated scipy misc imshow html我得到了运行时错误 无法执行图像查
  • 如何移动我的图像? python 3.10.4 pygame

    我会移动我的图像 图像是matiskinfinal png 我尝试将像素添加到 x 或其他我不知道它是什么的东西 因为我真的是 python 的初学者 pygame但是是 x x 变化 但图像没有移动 import os import py
  • Python 按照层次结构按多个分隔符分割字符串

    我只想根据多个分隔符 例如 and 和 按顺序分割字符串一次 例子 121 34 adsfd gt 121 34 adsfd dsfsd and adfd gt dsfsd adfd dsfsd adfd gt dsfsd adfd dsf
  • 如何从 Lua 调用 Python 函数?

    我想从我的 lua 文件运行 python 脚本 我怎样才能实现这个目标 Example Python代码 sum py file def sum from python a b return a b Lua code main lua f
  • 使用 boto3 从 s3 下载时使用 filename 作为文件名

    我正在使用 boto3 上传文件 如下所示 client boto3 client s3 aws access key id id aws secret access key key client upload file tmp test
  • Flask 中的 import 和 extends 有什么区别?

    我正在阅读 Flask Web 开发 在例4 3中 extends base html import bootstrap wtf html as wtf 我想知道 extends 和 import 有什么区别 我认为它们在用法上很相似 在什
  • python-polars 通过分隔符将字符串列拆分为许多列

    在 pandas 中 以下代码会将 col1 中的字符串拆分为许多列 有没有办法在极地做到这一点 d col1 a b c d a b c d df pd DataFrame data d df a b c d df col1 str sp
  • 是否有更矢量化的方法来沿轴执行 numpy.outer ?

    gt gt gt x np array a0 a1 b0 b1 gt gt gt y np array x0 x1 y0 y1 gt gt gt iterable np outer x i y i for i in xrange x sha
  • 如何为 Python 中的应用程序设置专用屏幕区域?

    MS OneNote 就是一个很好的例子 它可以选择固定在屏幕的一侧 并将所有其他窗口推到一侧 当最大化或调整其他窗口大小时 它们只能扩展到 OneNote 的边缘 Python 使用 Tkinter 或其他模块是否具有此功能 感谢您的帮助
  • PyCharm 无法识别字典值类型

    我有一个简单的代码片段 其中我将字典值设置为空列表 new dict for i in range 1 13 new dict i 现在 如果在下一行的循环内我会输入new dict i 并添加一个点 我希望 PyCharm 向我显示可用于
  • pandas to_sql sqlalchemy 与 secure_transport 的连接

    我正在尝试将数据发送到具有 require secure transport ON 的服务器上的 mysql 数据库 当我尝试使用以下代码连接到它时 import pandas as pd import pymysql from sqlal
  • 使 np.loadtxt 使用多个可能的分隔符

    我有一个程序可以读取数据文件 用户可以选择他们想要使用的列 我希望它对于输入文件更加通用 有时 列可能如下所示 10 34 24 58 8 284 6 121 有时它们可 能看起来像这样 10 34 24 58 8 284 6 121 我希
  • 如何在 Pytorch 中将一维 IntTensor 转换为 int

    如何将一维 IntTensor 转换为整数 这 IntTensor int 给出错误 KeyError Variable containing 423 torch IntTensor of size 1 我所知道的最简单 最干净的方法 In
  • 如何测试列表中多个值的成员资格

    我想测试两个或多个值是否在列表中具有成员资格 但我得到了意外的结果 gt gt gt a b in b a foo bar a True 那么 Python 可以同时测试列表中多个值的成员资格吗 这个结果意味着什么 See also How
  • Scrapy 抓取并跟踪 href 中的链接

    我对 scrapy 很陌生 我需要从 url 的主页跟踪 href 到多个深度 再次在 href 链接内我有多个 href 我需要遵循这些href 直到到达我想要抓取的页面 我的页面的示例 html 是 初始页 div class page
  • 如何保持 python 3 脚本 (Bot) 运行

    不是母语英语 抱歉 英语可能很蹩脚 我也是编程新手 您好 我正在尝试使用 QueryServer 连接到 TeamSpeak 服务器来创建机器人 经过几天的努力 它有效 只有 1 个问题 而我却被这个问题困扰了 如果您需要检查 这是我正在使

随机推荐