如何在tf.data.Dataset.map()中使用Keras的predict_on_batch?

2024-02-06

我想找到一种使用 Keras 的方法predict_on_batch inside tf.data.Dataset.map() in TF2.0.

假设我有一个 numpy 数据集

n_data = 10**5
my_data    = np.random.random((n_data,10,1))
my_targets = np.random.randint(0,2,(n_data,1))

data = ({'x_input':my_data}, {'target':my_targets})

and a tf.keras model

x_input = Input((None,1), name = 'x_input')
RNN     = SimpleRNN(100,  name = 'RNN')(x_input)
dense   = Dense(1, name = 'target')(RNN)

my_model = Model(inputs = [x_input], outputs = [dense])
my_model.compile(optimizer='SGD', loss = 'binary_crossentropy')

我可以批量创建一个dataset with

dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.batch(10)
prediction_dataset = dataset.map(transform_predictions)

where transform_predictions是一个用户定义的函数,它从以下位置获取预测predict_on_batch

def transform_predictions(inputs, outputs):
    predictions = my_model.predict_on_batch(inputs)
    # predictions = do_transformations_here(predictions)
    return predictions

这给出了一个错误predict_on_batch:

AttributeError: 'Tensor' object has no attribute 'numpy'

据我所理解,predict_on_batch需要一个 numpy 数组,并且它正在从数据集中获取一个张量对象。

似乎一种可能的解决方案是包装predict_on_batch在`tf.py_function中,尽管我也无法让它工作。

有谁知道如何做到这一点?


Dataset.map() 返回<class 'tensorflow.python.framework.ops.Tensor'>它没有 numpy() 方法。

迭代数据集返回<class 'tensorflow.python.framework.ops.EagerTensor'>它有一个 numpy() 方法。

将急切的张量提供给 Predict() 系列方法效果很好。

你可以尝试这样的事情:

dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.batch(10)

for x,y in dataset:
    predictions = my_model.predict_on_batch(x['x_input'])
    #or 
    predictions = my_model.predict_on_batch(x)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在tf.data.Dataset.map()中使用Keras的predict_on_batch? 的相关文章

  • Python Popen 与 psexec 挂起 - 不良结果

    我对 subprocess Popen 和我认为是管道的问题有疑问 我有以下代码块 从 cli 运行时 100 都不会出现问题 p subprocess Popen psexec serverName get cmd c ver echo
  • 使用 python 进行串行数据记录

    Intro 我需要编写一个小程序来实时读取串行数据并将其写入文本文件 我在读取数据方面取得了一些进展 但尚未成功地将这些信息存储在新文件中 这是我的代码 from future import print function import se
  • python future 和元组解包

    实现像使用 future 进行元组解包这样的事情的优雅 惯用的方法是什么 我有这样的代码 a b c f x y g a b z h y c 我想将其转换为使用期货 理想情况下我想写一些类似的东西 a b c ex submit f x y
  • Python模块可以访问英语词典,包括单词的定义[关闭]

    Closed 这个问题不符合堆栈溢出指南 help closed questions 目前不接受答案 我正在寻找一个 python 模块 它可以帮助我从英语词典中获取单词的定义 当然有enchant 这可以帮助我检查该单词是否存在于英语中
  • 在 Python distutils 中从 setup.py 查找脚本目录的正确方法?

    我正在分发一个具有以下结构的包 mymodule mymodule init py mymodule code py scripts script1 py scripts script2 py The mymodule的子目录mymodul
  • 如何计算numpy数组中元素的频率?

    我有一个 3 D numpy 数组 其中包含重复的元素 counterTraj shape 13530 1 1 例如 counterTraj 包含这样的元素 我只显示了几个元素 array 136 129 130 103 102 101 我
  • Pandas 数据帧到 numpy 数组 [重复]

    这个问题在这里已经有答案了 我对 Python 很陌生 经验也很少 我已经设法通过复制 粘贴和替换我拥有的数据来使一些代码正常工作 但是我一直在寻找如何从数据框中选择数据 但无法理解这些示例并替换我自己的数据 总体目标 如果有人真的可以帮助
  • 以同步方式使用 FastAPI,如何获取 POST 请求的原始正文?

    在中使用 FastAPIsync not async模式 我希望能够接收 POST 请求的原始 未更改的正文 我能找到的所有例子都显示async代码 当我以正常同步方式尝试时 request body 显示为协程对象 当我通过发布一些内容来
  • 使用 Python pandas 计算调整后的成本基础(股票买入/卖出的投资组合分析)

    我正在尝试对我的交易进行投资组合分析 并尝试计算调整后的成本基础价格 我几乎尝试了一切 但似乎没有任何效果 我能够计算调整后的数量 但无法获得调整后的购买价格有人可以帮忙吗 这是示例交易日志原始数据 import pandas as pd
  • 对图像块进行多重处理

    我有一个函数必须循环遍历图像的各个像素并计算一些几何形状 此函数需要很长时间才能运行 在 24 兆像素图像上大约需要 5 小时 但似乎应该很容易在多个内核上并行运行 然而 我一生都找不到一个有据可查 解释充分的例子来使用 Multiproc
  • 从 python 发起 SSH 隧道时出现问题

    目标是在卫星服务器和集中式注册数据库之间建立 n 个 ssh 隧道 我已经在我的服务器之间设置了公钥身份验证 因此它们只需直接登录而无需密码提示 怎么办 我试过帕拉米科 它看起来不错 但仅仅建立一个基本的隧道就变得相当复杂 尽管代码示例将受
  • Numpy 过滤器平滑零区域

    我有一个 0 及更大整数的 2D numpy 数组 其中值代表区域标签 例如 array 9 9 9 0 0 0 0 1 1 1 9 9 9 9 0 7 1 1 1 1 9 9 9 9 0 2 2 1 1 1 9 9 9 8 0 2 2 1
  • 奇怪的 MySQL Python mod_wsgi 无法连接到 'localhost' (49) 上的 MySQL 服务器问题

    StackOverflow上也有类似的问题 但我还没有发现完全相同的情况 这是在使用 MySQL 的 OS X Leopard 机器上 一些起始信息 MySQL Server version 5 1 30 Apache 2 2 13 Uni
  • 首先对列表中最长的项目进行排序

    我正在使用 lambda 来修改排序的行为 sorted list key lambda item item lower len item 对包含元素的列表进行排序A1 A2 A3 A B1 B2 B3 B 结果是A A1 A2 A3 B
  • 将 matplotlib 颜色图集中在特定值上

    我正在使用 matplotlib 颜色图 seismic 绘制绘图 并且希望白色以 0 为中心 当我在不进行任何更改的情况下运行脚本时 白色从 0 下降到 10 我尝试设置 vmin 50 vmax 50 但在这种情况下我完全失去了白色 关
  • 将 JSON 对象传递给带有请求的 url

    所以 我想利用 Kenneth 的优秀请求模块 https github com kennethreitz requests 在尝试使用时偶然发现了这个问题自由库API http wiki freebase com wiki API 基本上
  • 使用 Firefox 绕过弹出窗口下载文件:Selenium Python

    我正在使用 selenium 和 python 来从中下载某些文件web page http www oceanenergyireland com testfacility corkharbour observations 我之前一直使用设
  • 使用 PyTorch 分布式 NCCL 连接失败

    我正在尝试使用 torch distributed 将 PyTorch 张量从一台机器发送到另一台机器 dist init process group 函数正常工作 但是 dist broadcast 函数中出现连接失败 这是我在节点 0
  • 无法在前端使用 JavaScript Fetch API 将文件上传到 FastAPI 后端

    我正在尝试弄清楚如何将图像发送到我的 API 并验证生成的token那是在header的请求 到目前为止 这就是我所处的位置 app post endreProfilbilde async def endreProfilbilde requ
  • 具有自定义值的 Django 管理外键下拉列表

    我有 3 个 Django 模型 class Test models Model pass class Page models Model test models ForeignKey Test class Question model M

随机推荐

  • 如何选择列表中所有无序的元素?

    这个问题源于评论里的讨论这个答案 https stackoverflow com questions 1390832 how to sort nearly sorted array in the fastest time possible
  • 如何使用executeReader()方法检索一个单元格的值

    我需要执行以下命令并将结果传递给标签 我不知道如何使用 Reader 来做到这一点 有人可以帮我吗 String sql SELECT FROM learer WHERE learer id index SqlCommand cmd new
  • 使用 CoreData 嵌套撤消组

    我想将撤消管理器添加到 coredata 支持的 iPhone 应用程序中 当用户尝试添加新对象 通过点击 按钮 时 我加载一个新的模式视图控制器并在 viewDidLoad 中启动一个新的撤消组 当用户按下 取消 按钮时 我想回滚 can
  • 删除 Spark 数据框中重复的所有记录

    我有一个包含多列的 Spark 数据框 我想找出并删除列中具有重复值的行 其他列可能不同 我尝试使用dropDuplicates col name 但它只会删除重复的条目 但仍会在数据框中保留一条记录 我需要的是删除最初包含重复条目的所有条
  • Google 街景中像素距地面的高度/标高

    我正在寻找谷歌街景中每个像素距地面的高度 我知道可以计算的几件事是 像素间距 https stackoverflow com questions 21591462 get heading and pitch from pixels on s
  • 删除特定的kafka消息

    我想指示 kafka 尽可能删除一条消息 如果使用键和日志压缩 可以将键设置为消息 ID 并将消息内容设置为 null 但我寻找更直接的东西 不依赖于设置密钥 例如通过消息 ID None
  • 如何在 NSMenuItem 内绘制内联样式标签(或按钮)

    当 App Store 有更新时 它会在菜单项中显示一个内联样式元素 如下面屏幕截图中的 1 new 另一个我们可以看到这种菜单的地方是10 10 Yosemite的分享菜单 当您安装任何添加新共享扩展的应用程序时 共享菜单中的 更多 项目
  • AWS SSO、Codecommit(GRC git 克隆链接)和 npm install

    单点登录 SSO 在 AWS 账户上实施 运行后aws sso login 使用 GRC 链接 克隆节点和存储库是可行的 然而 运行npm install在 repo 中会导致不同的错误 前任 包 json dependencies com
  • 如何处理极长的LSTM序列长度?

    我有一些数据以非常高的速率 大约每秒数百次 采样 对于任何给定实例 这会导致平均序列长度很大 约 90 000 个样本 整个序列有一个标签 我正在尝试使用 LSTM 神经网络将新序列分类为这些标签之一 多类分类 然而 使用具有如此大序列长度
  • 如何更改浮动占位符的角度材料表单字段中的字体大小

    下面是角材料的形状场 当占位符正常和浮动时 如何为占位符添加 2 个不同的自定义字体大小 字体大小 20px 正常时 字体大小 13px 当它浮起来并变小时
  • 推送路线时将对象作为 prop 传递

    该功能位于路由器视图之外的组件中 goToMarkets this router push path markets params stock this model 但该道具在 市场 组件中未定义 Router const routes p
  • 如何使用 es6 js 类表示法自动递增 id 值?

    我在 es6 中的类方面遇到一些问题 每次创建对象时 我都需要自动递增 id 值 真的不明白我如何声明变量 为 id 赋值 然后递增增量变量 class Rectangle constructor name width height x y
  • 将报告导出为 PDF 时更改字体

    我在用着贾斯珀软件工作室 5 2 我做了一份报告快递新字体 当我将其导出到 PDF 时 它会将字体更改为Arial 我只使用Studio工具 当我预览报告时一切正常 但当我导出时就会发生这种情况 我可以如何处理我的报告以导出快递新 font
  • 在 Maxima 列表中查找最大值和索引?

    我有一个 maxima 列表 例如 x 1 3 7 98 211 3 2 44 23 我需要找到列表的最大值以及最大值位于哪个位置 我唯一想到的是将列表重写为序列并应用 max 命令 max first x second x last x
  • 具有摊销 O(1) 删除和 O(log n) 搜索的数据结构

    我需要一个支持两种操作的数据结构 删除和搜索 现在 删除操作应该运行在摊销 O 1 时间 而搜索应该运行在O log n time 搜索操作应该如下工作 查找指定的值 如果它在这里 则返回值本身 否则 返回最接近的较大值 返回有序后继 这个
  • 具有输入数组的方法

    我想要一种方法 可以像 NSArray 一样放置所需数量的参数 id initWithObjects id firstObj NS REQUIRES NIL TERMINATION 然后我可以使用 NSArray array NSArray
  • 获取 Boto3 中具有特定标签和值的 EC2 实例列表

    如何使用标签和值过滤 AWS 实例boto3 import boto3 ec2 boto3 resource ec2 client boto3 client ec2 response client describe tags Filters
  • 找到svg形状的中心

    我在用着svgjs创造我的形状 如何找到 svg 形状的中心点并在那里添加元素 就我而言 是一个红点 我在文档中找不到任何有助于解决这种情况的方法或内容的信息 你可以使用该方法getBBox https docs webplatform o
  • 实体框架异常“底层提供程序在打开时失败”

    我创建了一个 Windows 服务 它侦听 TCP IP 端口并使用实体框架将接收到的数据保存在数据库中 大多数时候它工作正常 但有时会抛出异常 底层提供程序打开失败 在数据库中保存数据 这是我的异常详细信息 Exception 2 27
  • 如何在tf.data.Dataset.map()中使用Keras的predict_on_batch?

    我想找到一种使用 Keras 的方法predict on batch inside tf data Dataset map in TF2 0 假设我有一个 numpy 数据集 n data 10 5 my data np random ra