在 Tensorflow 中检索 LSTM 序列的最后一个值

2024-03-30

我有不同长度的序列,想在 Tensorflow 中使用 LSTM 进行分类。对于分类,我只需要每个序列最后一个时间步长的 LSTM 输出。

max_length = 10
n_dims = 2
layer_units = 5
input = tf.placeholder(tf.float32, [None, max_length, n_dims])
lengths =  tf.placeholder(tf.int32, [None])
cell = tf.nn.rnn_cell.LSTMCell(num_units=layer_units, state_is_tuple=True)

sequence_outputs, last_states = tf.nn.dynamic_rnn(cell, sequence_length=lengths, inputs=input)

我想用 NumPy 表示法得到:output = sequence_outputs[:,lengths]

有什么方法或解决方法可以在 Tensorflow 中实现这种行为吗?

- -更新 - -

关注此帖子如何从 TensorFlow 中的 3-D 张量中选择行? https://stackoverflow.com/questions/36088277/how-to-select-rows-from-a-3-d-tensor-in-tensorflow看来可以有效地解决问题tf.gather并操纵指数。唯一的要求是必须提前知道批量大小。以下是对所提到的帖子针对这个具体问题的改编:

max_length = 10
n_dims = 2
layer_units = 5
batch_size = 2
input = tf.placeholder(tf.float32, [batch_size, max_length, n_dims])
lengths =  tf.placeholder(tf.int32, [batch_size])
cell = tf.nn.rnn_cell.LSTMCell(num_units=layer_units, state_is_tuple=True)

sequence_outputs, last_states = tf.nn.dynamic_rnn(cell,
                                                  sequence_length=lengths, inputs=input)

#Code adapted from @mrry response in StackOverflow:
#https://stackoverflow.com/questions/36088277/how-to-select-rows-from-a-3-d-tensor-in-tensorflow
rows_per_batch = tf.shape(input)[1]
indices_per_batch = 1

# Offset to add to each row in indices. We use `tf.expand_dims()` to make 
# this broadcast appropriately.
offset = tf.range(0, batch_size) * rows_per_batch

# Convert indices and logits into appropriate form for `tf.gather()`. 
flattened_indices = lengths - 1 + offset
flattened_sequence_outputs = tf.reshape(self.sequence_outputs, tf.concat(0, [[-1],
                             tf.shape(sequence_outputs)[2:]]))

selected_rows = tf.gather(flattened_sequence_outputs, flattened_indices)
last_output  = tf.reshape(selected_rows,
                          tf.concat(0, [tf.pack([batch_size, indices_per_batch]),
                          tf.shape(self.sequence_outputs)[2:]]))

@petrux 选项(获取 TensorFlow 中动态_rnn 的最后输出 https://stackoverflow.com/questions/41273361/get-th-last-output-of-a-dynamic-rnn-in-tensorflow/41273737#41273737)似乎也可以工作,但在 for 循环中构建列表的需要可能不太优化,尽管我没有执行任何基准测试来支持此语句。


This https://stackoverflow.com/a/41273737/1861627可能是一个答案。我不认为有任何与你指出的 NumPy 表示法类似的东西,但效果是一样的。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在 Tensorflow 中检索 LSTM 序列的最后一个值 的相关文章

随机推荐

  • SQLite 到核心数据迁移

    我在 App Store 上有一个实时应用程序 它使用SQLite作为数据库 现在我要实施下一个更新核心数据在应用程序中加载 sqlite 文件中的所有数据 而不会破坏应用程序 我一直在阅读教程 但没有多大帮助 我不知道如何继续 请为我指出
  • PuLP LpStatus=Undefined 的实际含义是什么?

    当我向问题添加特定约束时 解决后问题的 LpStatus 更改为 未定义 没有此约束 它是 最佳 在这个页面的顶部 显示了退货状态的可能性 但似乎没有解释它们的含义 谁能解释一下 未定义 状态是什么意思 它类似于指定约束时的语法错误吗 求解
  • SVG 文件:使用 Inkscape 将 PNG 文件转换为 SVG 文件后如何获取路径标记数据?

    我问是否有一个工具可以将 PNG 文件转换为 SVG 文件 我可以在其中获取路径标记 我尝试过使用 Inkscape 但是当我编辑 SVG 文件时 我找不到路径标记 只有 二进制 数据 SVG 文件路径标记示例 https www w3 o
  • 如何从命令行启动 Mac OS X 应用程序?

    open a 不是想要的答案 因为我想自动调试 Mac OS X 应用程序 这意味着如果有人可以给出这样的命令行会更好 程序 参数 格式 所以ltrace机制可以使 程序 作为调试目标并采取 args 作为输入 我尝试过像 Applicat
  • Xamarin Forms PCL - 干净、简单的网络请求方法?

    我正在使用便携式类库构建 Android iOS xamarin 表单应用程序 我正在寻找在 PCL 项目中执行此示例的最佳方法 https msdn microsoft com en us library 456dfw4f v vs 11
  • 将混合数字、分数和整数的字符向量转换为数值

    我正在尝试编写一个 R 函数来将分数和带分数转换为小数 例如 mixedToFloat lt function x x lt sub x fixed TRUE return unlist lapply x function x eval p
  • 熊猫,融化,未融化保存指数

    我有一张客户表 铜 和资产分配表 资产 A 1 2 3 4 5 6 idx coper1 coper2 coper3 cols asset1 asset2 df pd DataFrame A index idx columns cols 所
  • 如何运行“sbt hello, world”?

    我正在尝试了解 Scala SBT 并正在浏览http www scala sbt org 0 13 docs Hello html http www scala sbt org 0 13 docs Hello html 第一步进展顺利 我
  • 使用端口 80 上的 runserver 在没有 Apache 的情况下运行 Django,并且可以在 LAN 外部访问

    在调试模式下 我可以运行 django web 该 web 可以通过公共 局域网内 访问 python manage py runserver 0 0 0 0 8000 那么 是否可以像通常的网络服务器一样直接在端口 80 可能带有域 上运
  • 如何使用 wix 将多个元素添加到 XML 配置文件中?

    我正在尝试使用 Wix 编辑 XML 文件 我正在使用与 Wix 3 7 捆绑在一起的 WixUtilExtension xml 文件是在 Visual Studio 2010 中为 C 应用程序创建的设置文件 在此文件中 我使用一个用于在
  • matplotlib 仅显示一组 10 个图形中的一个,就像幻灯片一样

    I have a set of 10 graphs based on X Y pairs In this example only 3 Displaying one graph is easy same to all graphs in t
  • 如何将字节数组转换为图像文件?

    我在我的 MVC Web 应用程序中浏览并上传了 png jpg 文件 我已将此文件作为 byte 存储在我的数据库中 现在我想读取 byte 并将其转换为原始文件 我怎样才能做到这一点 创建一个内存流 http msdn microsof
  • C 标准库和 C POSIX 库的区别

    我对 C standard lib 和 C POSIX lib 有点困惑 因为我发现 C POSIX lib 中定义的许多头文件也是 C standard lib 的一部分 所以 我假设 C standard lib 是由ANSI C组织定
  • iOS 应用程序仅在未调试时崩溃

    我正在使用 testflight 来测试我的应用程序 并且只有当应用程序是为临时构建并通过测试飞行分发时才会发生崩溃 相关崩溃报告详细信息如下 Date Time 2012 06 11 09 00 34 638 0800 OS Versio
  • PowerShell:如何设置文化?

    我尝试过了Set Culture CultureInfo vi VN但Powershell并没有改变我设定的文化 我通过打开Powershell ISE进行测试 看到我设置的文化已成功更改 如何使用 Powershell 更改我设定的文化
  • C# .net 相当于 HTTP_RAW_POST_DATA?

    想要在 C 中模仿 php 代码 我想捕获从以下 Flash Actionscript 发布的原始图像数据 function onSaveJPG e Event void var myEncoder JPGEncoder new JPGEn
  • PyMongo 不会迭代集合

    我在 Python PyMongo 中有奇怪的行为 dbh self connection test first dbh test 1 second dbh test 2 first collection records first fin
  • 如何将 C++ 程序连接到 WCF 服务?

    在我工作的地方 有一些用 C 编写的软件 还有一些用 C 编写的软件 最重要的 不久前 我们认为通过 Web 服务发送堆栈跟踪和异常信息来跟踪软件中任何可能的问题是一个好主意 因此 我使用了 WCF 服务 它获取信息并将其存储在数据库中并自
  • WCF 是否支持点对点实现?

    我正在尝试在 LAN 内实现点对点消息传递和文件共享实用程序 那么 WCF 支持 p2p 吗 有人尝试过通过 WCF 进行文件共享吗 是的 它确实 请参见如何在对等网络中设计状态共享 http msdn microsoft com en u
  • 在 Tensorflow 中检索 LSTM 序列的最后一个值

    我有不同长度的序列 想在 Tensorflow 中使用 LSTM 进行分类 对于分类 我只需要每个序列最后一个时间步长的 LSTM 输出 max length 10 n dims 2 layer units 5 input tf place