堆叠 LSTM 网络中每个 LSTM 层的输入是什么?

2024-01-12

我在理解堆叠 LSTM 网络中各层的输入输出流时遇到一些困难。假设我创建了一个如下所示的堆叠 LSTM 网络:

# parameters
time_steps = 10
features = 2
input_shape = [time_steps, features]
batch_size = 32

# model
model = Sequential()
model.add(LSTM(64, input_shape=input_shape,  return_sequences=True))
model.add(LSTM(32,input_shape=input_shape))

其中我们的堆叠 LSTM 网络由 2 个 LSTM 层组成,分别具有 64 个和 32 个隐藏单元。在这种情况下,我们期望在每个时间步,第一个 LSTM 层 -LSTM(64) - 将作为输入传递给第二个 LSTM 层 -LSTM(32) - 一个大小为[batch_size, time-step, hidden_unit_length],它表示第一个 LSTM 层在当前时间步的隐藏状态。让我困惑的是:

  1. 第二个 LSTM 层 -LSTM(32)- 是否接收为X(t)(作为输入)第一层的隐藏状态 -LSTM(64)- 其大小为[batch_size, time-step, hidden_unit_length]并将其传递到它自己的隐藏网络(在本例中由 32 个节点组成)?
  2. 如果第一个是真的,为什么input_shape当第二层仅处理第一层的隐藏状态时,第一个 -LSTM(64)- 和第二个 -LSTM(32)- 是相同的吗?在我们的例子中不应该有input_shape设置为[32, 10, 64]?

I found the LSTM visualization below very helpful (found here https://towardsdatascience.com/animated-rnn-lstm-and-gru-ef124d06cf45) but it doesn't expand on stacked-lstm networks: LSTM workings

任何帮助将不胜感激。 谢谢!


The input_shape仅第一层需要。后续层将前一层的输出作为其输入(因此它们的input_shape参数值被忽略)

下面的模型

model = Sequential()
model.add(LSTM(64, return_sequences=True, input_shape=(5, 2)))
model.add(LSTM(32))

代表以下架构

你可以从中验证它model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
lstm_26 (LSTM)               (None, 5, 64)             17152     
_________________________________________________________________
lstm_27 (LSTM)               (None, 32)                12416     
=================================================================

更换线路

model.add(LSTM(32))

with

model.add(LSTM(32, input_shape=(1000000, 200000)))

仍然会给你相同的架构(使用验证model.summary())因为input_shape被忽略,因为它将前一层的张量输出作为输入。

如果您需要如下所示的序列到序列架构

你应该使用以下代码:

model = Sequential()
model.add(LSTM(64, return_sequences=True, input_shape=(5, 2)))
model.add(LSTM(32, return_sequences=True))

应该返回一个模型

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
lstm_32 (LSTM)               (None, 5, 64)             17152     
_________________________________________________________________
lstm_33 (LSTM)               (None, 5, 32)             12416     
=================================================================
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

堆叠 LSTM 网络中每个 LSTM 层的输入是什么? 的相关文章

随机推荐

  • +CNMI命令:如何接收通知并保存到SIM卡收到的短信

    我需要收到成功发送的通知 我读了很多并尝试过 我有 GSM 调制解调器中兴K4510Z 我没有收到任何通知或保存到 SIM 卡 在我的测试中 我以为我的SIM卡坏了 所以我尝试AT CMGW将临时消息写入 SIM 卡 它成功并存在 所以最后
  • 仅使用下划线检查是否存在重复的数组对

    我想知道如何检查数组中的重复值对是否作为 javascript 中较大数组的一部分存在 你可以看到有一对重复的 1 2 所以函数应该返回true i e var arr 1 2 3 4 5 6 7 8 9 10 11 12 13 14 1
  • 如何让两个物体碰撞后粘在一起?

    我真的很困惑 我可以成功检测到碰撞 但我无法使参与碰撞的两个物体粘在一起 这是我的联系监听器 world setContactListener listener listener new ContactListener Override p
  • 获取 Android 通知以横幅形式显示

    我相当广泛地研究了各种术语 横幅 弹出 通知类型 但我似乎无法清楚地了解我 认为 的一个非常常见的问题 因此 如果由于缺乏术语而导致我错过了一个非常明显的解决方案 请提出建议 问题是这样的 我希望 Android 通知显示为从屏幕顶部掉落的
  • 如何自动填充 SQLAlchemy 数据库字段? (Flask-SQLAlchemy)

    我有一个简单的用户模型 定义如下 models py from datetime import datetime from myapp import db class User db Model id db Column db Intege
  • 在 Cakephp 中使用 $this->Auth 获取关联模型

    我正在使用 CakePHP 2 0 的集成 Auth 组件 我有以下表格 Users Groups Profiles 我的模型关系如下 User belongsTo Group User hasMany Profiles 登录该站点时 我注
  • 删除级联时是否有“反向”选项?

    假设我在 SQL Server 中有以下数据库 CREATE TABLE Order ID BIGINT IDENTITY 1 1 CONSTRAINT PK Order PRIMARY KEY CLUSTERED ID CREATE TA
  • 回收时注销

    在生产环境中 我有一个 IIS 托管的 ASP NET 应用程序 实际上是许多 Web 应用程序 每个应用程序都会消耗大量内存 但目前限制它的唯一方法是回收 nHibernate 似乎正在泄漏内存 并且它正在创建大量字符串集合 问题是 在回
  • 在 hashmap android 中添加 Arraylist> 中的值

    我必须获取数据列表 所以我使用了字符串的数组列表和列表 这里如何在地图上添加值 我使用了下面的代码 static final String KEY TITLE Category static final String KEY ARTICLE
  • 对 Angular2 中的对象数组进行排序

    我在 Angular2 中对对象数组进行排序时遇到问题 该对象看起来像 name t10 ts 1476778297100 value 32 339264 xid DP 049908 name t17 ts 1476778341100 va
  • ZF2 - 需要在特定条件失败时显示特定错误消息

    我正在使用 ZF2 表单验证 我必须验证两个字段 用户名 和 密码 一切正常 但我收到类似的消息 Please enter username Username can not be less than 3 characters Please
  • 在 Eclipse 启动时禁用插件

    我刚刚为 Eclipse 安装了一个插件 但结果 Eclipse 将不再启动 它说 有一个错误 或一些此类无信息的消息 如何在不加载插件的情况下启动 Eclipse 以便我可以实际卸载有问题的软件 正如另一个人提到的 您可以尝试 clean
  • 处理 R 中冲突的命名空间(不同包中的相同函数名称):重置包命名空间的优先级

    不同包的命名空间之间的名称冲突R可能是危险的 并且使用package function不幸的是没有普遍化R 是否有一个函数可以重置包命名空间相对于当前加载的所有其他命名空间的优先级 我们当然可以detach然后重新加载包 但是没有其他更实用
  • 如何使用 Google App Engine 重定向所有 URL

    我该如何配置app yaml文件将所有 URL 重定向到另一个 URL 例如我想要http example appspot com hello or http example appspot com hello28928723重定向到htt
  • 有 CSS 父选择器吗?

    我该如何选择 li 是锚元素的直接父元素吗 举个例子 我的 CSS 应该是这样的 li lt a active property value 显然 有多种方法可以使用 JavaScript 实现此目的 但我希望 CSS Level 2 本身
  • 使用 terraform 获取金库秘密值

    我正在使用带有 consul 的保管库服务器作为存储后端 并尝试使用 terraform 中的保管库提供程序获取密码值 但它并没有获得它的价值 我将我的秘密存储在位置秘密 实例中 main tf provider vault address
  • 如何加速 Mongodump,转储未完成

    在尝试使用来自大约 50 亿个数据库的查询来运行数据库转储时 进度条时间似乎表明此转储不会在任何合理的时间 100 多天 内完成 大约 22 小时后 查询似乎以 0 结束后也冻结了 之后的行是metadata json 行 转储行是 mon
  • 复制同名属性的简单代码?

    我有一个old这个问题在我脑海里停留了很长时间 当我在 Spring 中编写代码时 有很多 DTO 域对象的脏代码和无用代码 对于语言级别 我对 Java 毫无希望 但在 Kotlin 中看到了一些曙光 这是我的问题 Style 1我们通常
  • 在 pyqt4 中旋转像素图会产生不需要的翻译

    我正在尝试编写一个简单的应用程序 在按下按钮时旋转 png 图像 我一切正常 只是当图像旋转时 它偏离了东南方向的中心 我本以为它不是绕着中心旋转 但每旋转 45 度它就会回到原点 这很奇怪 对于一个关键事件 我只是简单地调用 pixmap
  • 堆叠 LSTM 网络中每个 LSTM 层的输入是什么?

    我在理解堆叠 LSTM 网络中各层的输入输出流时遇到一些困难 假设我创建了一个如下所示的堆叠 LSTM 网络 parameters time steps 10 features 2 input shape time steps featur