使用 argmax 从 Tensor 获取值

2024-04-03

我有一个Tensor形状的(60, 128, 30000)。我想得到的值argmax of the 30000方面 (axis=2)。 此代码是一个示例:

tensor = tf.random.uniform((60, 128, 30000)) # shape (60, 128, 30000)
argmax = tf.argmax(tensor, axis=2) # shape (60, 128) --> max of each 30000

# do something to get every values of 30000
# argmax output (index)
<tf.Tensor: shape=(60, 128), dtype=int64, numpy=
array([[ 3229,  3079,  8360, ...,  1005, 16460,   872],
       [17808,  1253, 25476, ..., 16130,  3479,  3479],
       [27717, 25429, 18808, ...,  9787,  2603, 24011],
       ...,
       [25429, 25429,  5647, ..., 18451, 12453, 12453],
       [ 7361, 13463, 15864, ..., 18839, 12453, 12453],
       [ 4750, 25009, 11888, ...,  5647,  1993, 18451]], dtype=int64)>

# Desired output: each values of every index

With argmax,我得到了它们的索引数组,而不是它们的值。如何获得相同形状的数组(60, 128)他们的价值观?


你将不得不使用tf.meshgrid and tf.gather_nd实现你想要的:

tensor = tf.random.uniform((60, 128, 30000)) # shape (60, 128, 30000)
argmax = tf.argmax(tensor, axis=2)

ij = tf.stack(tf.meshgrid(
    tf.range(tensor.shape[0], dtype=tf.int64), 
    tf.range(tensor.shape[1], dtype=tf.int64),
                              indexing='ij'), axis=-1)

gather_indices = tf.concat([ij, tf.expand_dims(argmax, axis=-1)], axis=-1)
result = tf.gather_nd(tensor, gather_indices)
tf.print(result.shape)
TensorShape([60, 128])

Why is tf.meshgrid必要的?因为argmax确实包含您的索引,但形状错误。功能tf.gather_nd需要知道应该从 3D 张量中提取值的确切位置。这tf.meshgrid函数创建两个一维数组的矩形网格,表示第一维和第二维的张量索引。

import tensorflow as tf

tensor = tf.random.uniform((2, 5, 3))
argmax = tf.argmax(tensor, axis=2)

# result = tf.gather_nd(tensor, gather_ind) <-- Would not work because arxmax has the shape TensorShape([2, 5]) but  TensorShape([2, 5, 3]) is required
tf.print('Input tensor:\n', tensor, tensor.shape, '\nArgmax tensor:\n', argmax, argmax.shape)

i, j = tf.meshgrid(
    tf.range(tensor.shape[0], dtype=tf.int64), 
    tf.range(tensor.shape[1], dtype=tf.int64),
                              indexing='ij')

# You need to create a mesh grid to correctly index your tensor.

ij = tf.stack([i, j], axis=-1)
tf.print('Meshgrid:\n', i, j, summarize=-1)
tf.print('Stacked:\n', ij, summarize=-1)

gather_indices = tf.concat([ij, tf.expand_dims(argmax, axis=-1)], axis=-1)
tf.print('Gathered indices:\n', gather_indices, gather_indices.shape, summarize=-1)

result = tf.gather_nd(tensor, gather_indices)
tf.print('\nFinal result:\n', result, result.shape)
Input tensor:
 [[[0.889752269 0.243187189 0.601408958]
  [0.891950965 0.776625633 0.146243811]
  [0.136176467 0.743871331 0.762170076]
  [0.424416184 0.150568008 0.464055896]
  [0.308753 0.0792338848 0.383242]]

 [[0.741660118 0.49783361 0.935318112]
  [0.0616152287 0.0367363691 0.748341084]
  [0.397849679 0.765681744 0.502376914]
  [0.750188231 0.304993749 0.733741879]
  [0.31267941 0.778184056 0.546301]]] TensorShape([2, 5, 3]) 
Argmax tensor:
 [[0 0 2 2 2]
 [2 2 1 0 1]] TensorShape([2, 5])
Meshgrid:
 [[0 0 0 0 0]
 [1 1 1 1 1]] [[0 1 2 3 4]
 [0 1 2 3 4]]
Stacked:
 [[[0 0]
  [0 1]
  [0 2]
  [0 3]
  [0 4]]

 [[1 0]
  [1 1]
  [1 2]
  [1 3]
  [1 4]]]
Gathered indices:
 [[[0 0 0]
  [0 1 0]
  [0 2 2]
  [0 3 2]
  [0 4 2]]

 [[1 0 2]
  [1 1 2]
  [1 2 1]
  [1 3 0]
  [1 4 1]]] TensorShape([2, 5, 3])

Final result:
 [[0.889752269 0.891950965 0.762170076 0.464055896 0.383242]
 [0.935318112 0.748341084 0.765681744 0.750188231 0.778184056]] TensorShape([2, 5])

顺便说一句,您也可以考虑使用tf.math.top_k因为您想获得最后一个维度的最大值。该函数返回索引和值(您想要的):

tensor = tf.random.uniform((60, 128, 30000)) # shape (60, 128, 30000)
values, indices = tf.math.top_k(tensor,
                        k=1)
tf.print(tf.squeeze(values, axis=-1).shape)
TensorShape([60, 128])
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用 argmax 从 Tensor 获取值 的相关文章

  • 中断 Select 以添加另一个要在 Python 中监视的套接字

    我正在 Windows XP 应用程序中使用 TCP 实现点对点 IPC 我正在使用select and socketPython 2 6 6 中的模块 我有三个 TCP 线程 一个读取线程通常会阻塞select 一个通常等待事件的写入线程
  • 如何使用 opencv.omnidir 模块对鱼眼图像进行去扭曲

    我正在尝试使用全向模块 http docs opencv org trunk db dd2 namespacecv 1 1omnidir html用于对鱼眼图像进行扭曲处理Python 我正在尝试适应这一点C 教程 http docs op
  • Python 中的舍入浮点问题

    我遇到了 np round np around 的问题 它没有正确舍入 我无法包含代码 因为当我手动设置值 而不是使用我的数据 时 返回有效 但这是输出 In 177 a Out 177 0 0099999998 In 178 np rou
  • 用枢轴点拟合曲线 Python

    我有下面的图 我想用 2 条线来拟合它 使用 python 我设法适应上半部分 def func x a b x np array x return a x b popt pcov curve fit func up x up y 我想用另
  • 使用 Python 从文本中删除非英语单词

    我正在 python 上进行数据清理练习 我正在清理的文本包含我想删除的意大利语单词 我一直在网上搜索是否可以使用像 nltk 这样的工具包在 Python 上执行此操作 例如给出一些文本 Io andiamo to the beach w
  • 使用 kivy textinput 的 'input_type' 属性的问题

    您好 我在使用 kivy 的文本输入小部件的 input type 属性时遇到问题 问题是我制作了两个自定义文本输入 其中一个称为 StrText 其中设置了 input type text 然后是第二个文本输入 名为 NumText 其
  • 独立滚动矩阵的行

    我有一个矩阵 准确地说 是 2d numpy ndarray A np array 4 0 0 1 2 3 0 0 5 我想滚动每一行A根据另一个数组中的滚动值独立地 r np array 2 0 1 也就是说 我想这样做 print np
  • 使用字典映射数据帧索引

    为什么不df index map dict 工作就像df column name map dict 这是尝试使用index map的一个小例子 import pandas as pd df pd DataFrame one A 10 B 2
  • 如何使用 Pandas、Numpy 加速 Python 中的嵌套 for 循环逻辑?

    我想检查一下表的字段是否TestProject包含了Client端传入的参数 嵌套for循环很丑陋 有什么高效简单的方法来实现吗 非常感谢您的任何建议 def test parameter a list parameter b list g
  • 如何将张量流模型部署到azure ml工作台

    我在用Azure ML Workbench执行二元分类 到目前为止 一切正常 我有很好的准确性 我想将模型部署为用于推理的 Web 服务 我真的不知道从哪里开始 azure 提供了这个doc https learn microsoft co
  • 如何使用 pybrain 黑盒优化训练神经网络来处理监督数据集?

    我玩了一下 pybrain 了解如何生成具有自定义架构的神经网络 并使用反向传播算法将它们训练为监督数据集 然而 我对优化算法以及任务 学习代理和环境的概念感到困惑 例如 我将如何实现一个神经网络 例如 1 以使用 pybrain 遗传算法
  • pyspark 将 twitter json 流式传输到 DF

    我正在从事集成工作spark streaming with twitter using pythonAPI 我看到的大多数示例或代码片段和博客是他们从Twitter JSON文件进行最终处理 但根据我的用例 我需要所有字段twitter J
  • 加快网络抓取速度

    我正在使用一个非常简单的网络抓取工具抓取 23770 个网页scrapy 我对 scrapy 甚至 python 都很陌生 但设法编写了一个可以完成这项工作的蜘蛛 然而 它确实很慢 爬行 23770 个页面大约需要 28 小时 我看过scr
  • Python3 在 DirectX 游戏中移动鼠标

    我正在尝试构建一个在 DirectX 游戏中执行一些操作的脚本 除了移动鼠标之外 我一切都正常 是否有任何可用的模块可以移动鼠标 适用于 Windows python 3 Thanks I used pynput https pypi or
  • import matplotlib.pyplot 给出 AttributeError: 'NoneType' 对象没有属性 'is_interactive'

    我尝试在 Pycharm 控制台中导入 matplotlib pyplt import matplotlib pyplot as plt 然后作为回报我得到 Traceback most recent call last File D Pr
  • 如何使用原始 SQL 查询实现搜索功能

    我正在创建一个由 CS50 的网络系列指导的应用程序 这要求我仅使用原始 SQL 查询而不是 ORM 我正在尝试创建一个搜索功能 用户可以在其中查找存储在数据库中的书籍列表 我希望他们能够查询 书籍 表中的 ISBN 标题 作者列 目前 它
  • 根据列 value_counts 过滤数据框(pandas)

    我是第一次尝试熊猫 我有一个包含两列的数据框 user id and string 每个 user id 可能有多个字符串 因此会多次出现在数据帧中 我想从中导出另一个数据框 一个只有那些user ids列出至少有 2 个或更多string
  • 为什么 Pickle 协议 4 中的 Pickle 文件是协议 3 中的两倍,而速度却没有任何提升?

    我正在测试 Python 3 4 我注意到 pickle 模块有一个新协议 因此 我对 2 个协议进行了基准测试 def test1 pickle3 open pickle3 wb for i in range 1000000 pickle
  • python import inside函数隐藏现有变量

    我在我正在处理的多子模块项目中遇到了一个奇怪的 UnboundLocalError 分配之前引用的局部变量 问题 并将其精简为这个片段 使用标准库中的日志记录模块 import logging def foo logging info fo
  • 如何应用一个函数 n 次? [关闭]

    Closed 这个问题需要细节或清晰度 help closed questions 目前不接受答案 假设我有一个函数 它接受一个参数并返回相同类型的结果 def increment x return x 1 如何制作高阶函数repeat可以

随机推荐