tf.constant 和 tf.placeholder 的行为不同

2023-12-11

我想将 tf.metrics 包装在 Sonnet 模块中以测量每个批次的性能,以下是我所做的工作:

import tensorflow as tf
import sonnet as snt

class Metrics(snt.AbstractModule):
    def __init__(self, indicator, summaries = None, name = "metrics"):
        super(Metrics, self).__init__(name = name)
        self._indicator = indicator
        self._summaries = summaries

    def _build(self, labels, logits):
        if self._indicator == "accuracy":
            metric, metric_update = tf.metrics.accuracy(labels, logits)
            with tf.control_dependencies([metric_update]):
                outputs = tf.identity(metric)
        elif self._indicator == "precision":
            metric, metric_update = tf.metrics.precision(labels, logits)
            with tf.control_dependencies([metric_update]):
                outputs = tf.identity(metric)
        elif self._indicator == "recall":
            metric, metric_update = tf.metrics.recall(labels, logits)
            with tf.control_dependencies([metric_update]):
                outputs = tf.identity(metric)
        elif self._indicator == "f1_score":
            metric_recall, metric_update_recall = tf.metrics.recall(labels, logits)
            metric_precision, metric_update_precision = tf.metrics.precision(labels, logits)
            with tf.control_dependencies([metric_update_recall, metric_update_precision]):
                outputs = 2.0 / (1.0 / metric_recall + 1.0 / metric_precision)
        else:
            raise ValueError("unsupported metrics")

        if type(self._summaries) == list:
            self._summaries.append(tf.summary.scalar(self._indicator, outputs))

        return outputs

但是,当我想测试该模块时,以下代码有效:

def test3():
    import numpy as np

    labels = tf.constant([1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], tf.int32)
    logits = tf.constant([1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], tf.int32)

    metrics = Metrics("accuracy")
    accuracy = metrics(labels, logits)

    metrics2 = Metrics("f1_score")
    f1_score = metrics2(labels, logits)

    writer = tf.summary.FileWriter("utils-const", tf.get_default_graph())
    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

        accu, f1 = sess.run([accuracy, f1_score])
        print(accu)
        print(f1)

    writer.close()

但是以下代码不起作用:

def test4():
    from tensorflow.python import debug as tf_debug
    import numpy as np

    tf_labels = tf.placeholder(dtype=tf.int32, shape=[None])
    tf_logits = tf.placeholder(dtype=tf.int32, shape=[None])

    labels = np.array([1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], np.int32)
    logits = np.array([1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], np.int32)

    metrics = Metrics("accuracy")
    accuracy = metrics(tf_labels, tf_logits)

    metrics2 = Metrics("f1_score")
    f1_score = metrics2(tf_labels, tf_logits)

    writer = tf.summary.FileWriter("utils-feed", tf.get_default_graph())
    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

        sess = tf_debug.LocalCLIDebugWrapperSession(sess)

        accu, f1 = sess.run([accuracy, f1_score], feed_dict = {tf_labels: labels, tf_logits: logits})
        print(accu)
        print(f1)

    writer.close()

test3() 的输出是正确的,0.88。 test4() 的输出是错误的,0.0。然而,它们应该是等价的。

有人有什么想法吗?


你确定这不是tf.constant失败的版本?我发现tf.metrics结合起来有奇怪的行为tf.constant:

import tensorflow as tf

a = tf.constant(1.)
mean_a, mean_a_uop = tf.metrics.mean(a)
with tf.control_dependencies([mean_a_uop]):
  mean_a = tf.identity(mean_a)

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
tf.local_variables_initializer().run()

for _ in range(10):
  print(sess.run(mean_a))

返回,当在 GPU 上运行时,

0.0
2.0
1.5
1.3333334
1.25
1.2
1.1666666
1.1428572
1.125
1.1111112

代替1s。看起来计数落后了一位。 (我假设第一个值是inf但由于某些条件而为零count)。另一方面,此代码的占位符版本正在按预期运行。

在 CPU 上,行为更加奇怪,因为输出是不确定的。输出示例:

0.0
1.0
1.0
0.75
1.0
1.0
0.85714287
0.875
1.0
0.9

看起来像是一个错误,您可以登录张量流的 github 存储库。 (请注意,对常量使用运行指标不太有用——但这仍然是一个错误)。

EDIT现在我也偶然发现了奇怪的例子tf.placeholder, 看起来tf.metrics不幸的是,有一个错误不仅限于它的使用tf.constants.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

tf.constant 和 tf.placeholder 的行为不同 的相关文章

  • Virtualenv 在 OS X Yosemite 上失败并出现 OSError

    我最近更新到 OSX Yosemite 现在无法使用virtualenv pip 每当我执行 virtualenv env 它抛出一个 OSError Command Users administrator ux env bin pytho
  • 多处理中的动态池大小?

    有没有办法动态调整multiprocessing Pool尺寸 我正在编写一个简单的服务器进程 它会产生工作人员来处理新任务 使用multiprocessing Process对于这种情况可能更适合 因为工作人员的数量不应该是固定的 但我需
  • 按边距(“全部”)值列对 Pandas 数据透视表进行排序

    我试图根据 pandas 数据透视表中的行总和对最后一列 边距 aggrfunc 进行降序排序 我知道我在这里错过了一些简单的东西 但我无法弄清楚 数据框 数据透视表 WIDGETS DATE 2 1 16 2 2 16 2 3 16 Al
  • 是否可以从 Julia 调用 Python 函数并返回其结果?

    我正在使用 Python 从网络上抓取数据 我想使用这些数据在 Julia 中运行计算 是否可以在 Julia 中调用该函数并返回其结果 或者我最好直接导出到 CSV 并以这种方式加载数据 绝对地 看PyCall jl https gith
  • 创建上下文后将 jar 文件添加到 pyspark

    我正在笔记本上使用 pyspark 并且不处理 SparkSession 的创建 我需要加载一个包含一些我想在处理 rdd 时使用的函数的 jar 您可以使用 jars 轻松完成此操作 但在我的特定情况下我无法做到这一点 有没有办法访问sp
  • AttributeError:“模块”对象没有属性[重复]

    这个问题在这里已经有答案了 我有两个 python 模块 a py import b def hello print hello print a py print hello print b hi b py import a def hi
  • Python将文本文件解析为嵌套字典

    考虑以下数据结构 HEADER1 key value key value HEADER2 key value key value HEADER3 key value HEADER4 key value key value 原始数据中没有缩进
  • 根据其他单元格值更改多个单元格值

    我想更改包含的单元格moving to movingToOpenor movingToClose基于下一个单元格中给出的状态 有时循环会被中断并且不会从open to close or close to open 这是我当前的数据框 Dat
  • 两个不同长度的数据帧的列之间的余弦相似度?

    我在 df1 中有文本列 在 df2 中有文本列 df2 的长度将与 df1 的长度不同 我想计算 df1 text 中每个条目与 df2 text 中每个条目的余弦相似度 并为每场比赛给出分数 输入样本 df1 mahesh suresh
  • python中basestring和types.StringType之间的区别?

    有什么区别 isinstance foo types StringType and isinstance foo basestring 对于Python2 basestring是两者的基类str and unicode while type
  • 查找 Pandas DF 行中的最短日期并创建新列

    我有一个包含多个日期的表 有些日期将为 NaN 我需要找到最旧的日期 所以一行可能有 DATE MODIFIED WITHDRAWN DATE SOLD DATE STATUS DATE 等 因此 对于每一行 一个或多个字段中都会有一个日期
  • pandas 相当于 np.where

    np where具有向量化 if else 的语义 类似于 Apache Spark 的when otherwise数据帧方法 我知道我可以使用np where on pandas Series but pandas通常定义自己的 API
  • 可以使用哪些技术来衡量 pandas/numpy 解决方案的性能

    Question 如何简洁全面地衡量下面各个功能的性能 Example 考虑数据框df df pd DataFrame Group list QLCKPXNLNTIXAWYMWACA Value 29 52 71 51 45 76 68 6
  • 如何在亚马逊 EC2 上调试 python 网站?

    我是网络开发新手 这可能是一个愚蠢的问题 但我找不到可以帮助我的确切答案或教程 我工作的公司的网站 用 python django 构建 托管在亚马逊 EC2 上 我想知道从哪里开始调试这个生产站点并检查存储在那里的日志和数据库 我有帐户信
  • AWS Lambda 不读取环境变量

    我正在编写一个 python 脚本来查询 Qualys API 中的漏洞元数据 我在 AWS 中将其作为 lambda 函数执行 我已经在控制台中设置了环境变量 但是当我执行函数时 出现以下错误 module initialization
  • 如何编写一个接受 int 或 float 的 C 函数?

    我想用 C 语言创建一个扩展 Python 的函数 该函数可以接受 float 或 int 类型的输入 所以基本上 我想要f 5 and f 5 5 成为可接受的输入 我认为我不能使用if PyArg ParseTuple args i v
  • 如何获取pandas中groupby对象中的组数?

    我想知道有多少个独特的组需要执行计算 给定一个名为 groupby 的对象dfgroup 我们如何找到组的数量 简单 快速 Pandaic ngroups 较新版本的 groupby API pandas gt 0 23 提供了此 未记录的
  • IndexError - 具有匀称形状的笛卡尔 PolygonPatch

    我曾经使用 shapely 制作一个圆圈并将其绘制在之前填充的图上 这曾经工作得很好 最近 我收到索引错误 我将代码分解为最简单的操作 但它甚至无法执行最简单的循环 import descartes import shapely geome
  • python从二进制文件中读取16字节长的双精度值

    我找到了蟒蛇struct unpack 读取其他程序生成的二进制数据非常方便 问题 如何阅读16 字节长双精度数出二进制文件 以下 C 代码将 1 01 写入二进制文件三次 分别使用 4 字节浮点型 8 字节双精度型和 16 字节长双精度型
  • 定义在文本小部件中双击时选择哪些字符

    在 Windows 上 双击文本小部件中的单词也将选择连接的标点符号 有什么方法可以定义您想要选择的角色吗 tcl wordchars该变量的值是一个正则表达式 可以设置它来控制什么被视为 单词 字符 例如 通过双击 Tk 中的文本来选择单

随机推荐

  • 在iOS7半透明导航栏中获取正确的颜色

    如何为 iOS 7 中的半透明导航栏获得正确的颜色 导航栏只是将给定的颜色调整为更亮的颜色 更改颜色的亮度或饱和度也无法提供正确的结果 有人有同样的烦恼吗 看看 Facebook 它似乎以某种方式起作用 他们有自己的颜色和半透明的导航栏 编
  • SSIS 循环遍历 Excel 工作表

    我正在使用SSIS2012 我试图将大约25个excel文件 每个文件包含大约70个 变量 表 导入到SQLserver2008中 我已经构建了它 以便它将循环遍历所有 Excel 工作表并导入第一个工作表 但这没有用 我如何循环所有 Ex
  • 将文件直接上传到 GAE 应用的 Google Cloud Storage

    我正在考虑从 Blobstore 切换到 Google Cloud Storage 以处理项目中的图像上传等问题 因为 Google 称 Blobstore 为 取代 在 Blobstore 中 多部分表单将直接提交 上传 到 Blobst
  • 在 ansible playbook 中使用 gitlab-ci vars

    我想使用 Ansible playbook 在 docker 容器内设置远程环境 该剧本将从 gitlab ci 运行 其中包含我在 Gitlab CI CD 配置中设置的变量 我怎样才能做到这一点 这是我想使用的模板 我该如何设置user
  • C 中的常量返回类型

    我正在阅读一些代码示例 它们返回了 const int 当我尝试编译示例代码时 出现了有关返回类型冲突的错误 所以我开始搜索 认为 const 是问题所在 当我删除它时 代码工作正常 不仅可以编译 而且按预期工作 但我从未能够找到专门与 c
  • 查找字符串中长度最大的所有单词

    我想从字符串中找到长度最大的所有单词 目前 结果只是第一个长度最大的 jumped1 而我想要它们全部 jumped1 jumped2 我该如何调整以下内容 function test str var newStr str split va
  • 同步多个 UITableView 实例的滚动位置

    我有一个项目 我需要在其中显示多个UITableViewiPad 上同一视图内的实例 它们也恰好被轮换 但我相当确定这是无关紧要的 用户应该不知道视图是由多个表视图组成的 因此 我想做到这一点 以便当我滚动一个表视图时 其他表视图也会同时滚
  • Hibernate Envers:跟踪 OneToMany 关系拥有方的修订

    我有两个经过审计的实体 A 和 B 实体 A 拥有实体 B 的集合 注释为一对多关系 将 A 的新实例插入数据库时 A 和 B 的所有行都处于同一修订版 假设为修订版 1 然后 A 上有更新 仅影响实体 B 的实例 级联类型为合并 因此 更
  • 如何使用 Puppeteer 访问 React 事件处理程序

    我不完全确定我明白我的要求 我希望有人能解释一下 我正在尝试在 NodeJS 上使用 Puppeteer 抓取网站 我已经选择了我需要的元素并访问它的属性 但是 我无法访问我需要的属性来提取我想要的信息 我想要的信息在下面的绿色框中 但是我
  • 为什么即使使用前向声明,我也不能在 BEGIN 块中调用稍后定义的 sub?

    这有效 use strict X xxxxxx sub X print shift 这会产生一个错误 use strict BEGIN X xxxxxx sub X print shift Error Undefined subroutin
  • 通过 REST 在超级账本上部署链代码时出现“获取链代码包字节时出错”

    我正在尝试通过 POST REST 在 hyperledger Bluemix 服务 上部署链码 链码 查询规范 jsonrpc 2 0 方法 部署 参数 类型 1 chaincodeID 路径 https github com romeo
  • 翻译微风验证消息

    改进我的示例 了解如何使用获得的元数据在淘汰赛中创建验证规则 http stackoverflow com questions 13662446 knockout validation using breeze utility 现在我使用微
  • 防止 GDB 单步执行函数(或文件)

    我有一些像这样的 C 代码 我正在使用 GDB 逐步执行 void foo int num void main Baz baz foo baz get 当我在main 我想步入foo 但我想跨过去baz get The GDB docs说
  • 如何推送(即刷新)发送到 TCP 流的数据

    RFC 793说TCP定义了一个 推送 函数来确保接收者收到数据 有时用户需要确保他们拥有的所有数据 提交给TCP已经传输了 为此目的一推 函数已定义 确保提交给 TCP 的数据是 实际传输的发送用户表明它应该是 推送给接收用户 推送会导致
  • 报亭应用程序需要推送通知吗?

    如果我提交一个不使用推送通知的报刊亭应用程序 而是在每次用户启动该应用程序时向我的服务器查询新内容 苹果会拒绝我的应用程序吗 IE 用户是否期望在所有报亭应用上推送 Thanks 不 Apple 的指南并不强制要求使用推送通知 并且您的应用
  • 在服务内调用 getSystemService

    我正在尝试编写一项在 Gear Live 上获取心率的服务 遵循此处的问题从 传感器 Samsung Gear Live 获取心率 如果我把这部分 Log d TAG prepare to call getSystemService mSe
  • 使用 Nodejs 和 pug 进行客户端模板化

    我正在构建一个网络应用程序 它在客户端构建了动态小部件 目前我使用nodejs和pug作为我的服务器端模板库 我喜欢pug的简单性 我希望在服务器上有一系列小的 pug 文件 客户端可以将其用作构建块来构造用户所需的小部件 我尝试使用此处找
  • 单击:如何将操作应用于所有命令和子命令,但允许命令选择退出?

    我有一个案例 我想自动运行一个常用函数 check upgrade 对于我的大多数单击命令和子命令 但在少数情况下我不想运行它 我想我可以有一个可以添加的装饰器 例如 bypass upgrade check 对于命令 其中check up
  • MongoDB 将字符串类型转换为浮点类型

    按照这里的建议MongoDB 如何更改字段的类型 我尝试更新我的集合以更改字段的类型及其值 这是更新查询 db MyCollection find ProjectID 44 Cost exists true forEach function
  • tf.constant 和 tf.placeholder 的行为不同

    我想将 tf metrics 包装在 Sonnet 模块中以测量每个批次的性能 以下是我所做的工作 import tensorflow as tf import sonnet as snt class Metrics snt Abstrac