TF多层 LSTM 以及 State 之间的融合

2023-05-16

第一是实现多层的LSTM的网络;
第二是实现两个LSTM的state的concat操作, 分析 state 的结构.

对于第一个问题,之前一直没有注意过, 看下面两个例子:

在这里插入代码片
import tensorflow as tf

num_units = [20, 20]

#Unit1, OK
# X = tf.random_normal(shape=[3, 5, 6], dtype=tf.float32)
# X = tf.reshape(X, [-1, 5, 6])
# multi_rnn = [tf.nn.rnn_cell.BasicLSTMCell(num_units=units) for units in num_units]
# lstm_cells = tf.contrib.rnn.MultiRNNCell(multi_rnn)
# output, state = tf.nn.dynamic_rnn(lstm_cells, X, time_major=True, dtype=tf.float32)

#Unit2, OK
# X = tf.random_normal(shape=[3, 5, 6], dtype=tf.float32)
# X = tf.reshape(X, [-1, 5, 6])
# multi_rnn = []
# for i in range(2):
#     multi_rnn.append(tf.nn.rnn_cell.BasicLSTMCell(num_units=num_units[i]))
# lstm_cells = tf.contrib.rnn.MultiRNNCell(multi_rnn)
# output, state = tf.nn.dynamic_rnn(lstm_cells, X, time_major=True, dtype=tf.float32)

# Unit3 *********ERROR***********
X = tf.random_normal(shape=[3, 5, 6], dtype=tf.float32)
X = tf.reshape(X, [-1, 5, 6])
# single_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=20) # same as below
lstm_cells = tf.contrib.rnn.MultiRNNCell([tf.nn.rnn_cell.BasicLSTMCell(num_units=20)] * 2)
output, state = tf.nn.dynamic_rnn(lstm_cells, X, time_major=True, dtype=tf.float32)

print(output)
print(state)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for var in tf.global_variables():
        print(var.op.name)
    output_run, state_run = sess.run([output, state])
   
 之前还真没注意到这个问题, 虽然一般都是多层的维度一致,但是都是写成 Unit2 这种形式.

第二个问题两个 Encoder 的 State 的融合, 并保持 State 类型 (LSTM/GRU)

import tensorflow as tf

def concate_rnn_states(num_layers, encoder_state_local, encoder_state_global):
    '''
    :param num_layers:
    :param encoder_fw_state:
    For LSTM:
    (LSTMStateTuple(c=<tf.Tensor 'encoder1/rnn/while/Exit_3:0' shape=(3, 20) dtype=float32>,
        h=<tf.Tensor 'encoder1/rnn/while/Exit_4:0' shape=(3, 20) dtype=float32>),
    LSTMStateTuple(c=<tf.Tensor 'encoder1/rnn/while/Exit_5:0' shape=(3, 20) dtype=float32>,
        h=<tf.Tensor 'encoder1/rnn/while/Exit_6:0' shape=(3, 20) dtype=float32>))
    For GRU:
    (<tf.Tensor 'encoder1/rnn/while/Exit_3:0' shape=(3, 20) dtype=float32>,
        <tf.Tensor 'encoder1/rnn/while/Exit_4:0' shape=(3, 20) dtype=float32>)
    :param encoder_bw_state: same as fw
    :return: tuple
    '''
    encoder_states = []
    for i in range(num_layers):
        if isinstance(encoder_state_local[i], tf.nn.rnn_cell.LSTMStateTuple):
            # for lstm
            encoder_state_c = tf.concat(values=(encoder_state_local[i].c, encoder_state_global[i].c), axis=1,
                                        name="concat_layer{}_state_c".format(i))
            encoder_state_h = tf.concat(values=(encoder_state_local[i].h, encoder_state_global[i].h), axis=1,
                                        name="concat_layer{}_state_h".format(i))
            encoder_state = tf.contrib.rnn.LSTMStateTuple(c=encoder_state_c, h=encoder_state_h)
        elif isinstance(encoder_state_local[i], tf.Tensor):
            # for gru and rnn
            encoder_state = tf.concat(values=(encoder_state_local[i], encoder_state_global[i]), axis=1,
                                      name='GruOrRnn_concat')

        encoder_states.append(encoder_state)
    return tuple(encoder_states)

num_units = [20, 20]

#Unit1
X = tf.random_normal(shape=[3, 5, 6], dtype=tf.float32)
X = tf.reshape(X, [-1, 5, 6])

with tf.variable_scope("encoder1") as scope:
    local_multi_rnn = [tf.nn.rnn_cell.GRUCell(num_units=units) for units in num_units]
    local_lstm_cells = tf.contrib.rnn.MultiRNNCell(local_multi_rnn)
    encoder_output_local, encoder_state_local = tf.nn.dynamic_rnn(local_lstm_cells, X, time_major=False, dtype=tf.float32)

with tf.variable_scope("encoder2") as scope:
    global_multi_rnn = [tf.nn.rnn_cell.GRUCell(num_units=units) for units in num_units]
    global_lstm_cells = tf.contrib.rnn.MultiRNNCell(global_multi_rnn)
    encoder_output_global, encoder_state_global = tf.nn.dynamic_rnn(global_lstm_cells, X, time_major=False, dtype=tf.float32)

print("concat output")
encoder_outputs = tf.concat([encoder_output_local, encoder_output_global], axis=-1)
print(encoder_output_local)
print(encoder_outputs)

print("concat state")
print(encoder_state_local)
print(encoder_state_global)
encoder_states = concate_rnn_states(2, encoder_state_local, encoder_state_global)
print(encoder_states)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TF多层 LSTM 以及 State 之间的融合 的相关文章

  • 如何列出检查点中的某些变量?

    我正在使用自动编码器 我的检查点包含网络的完整状态 即编码器 解码器 优化器等 我想玩弄编码 因此 在我的评估模式中 我只需要网络的解码器部分 如何从现有检查点中仅读取一些特定变量 以便我可以在另一个模型中重用它们的值 There s li
  • android:媒体记录器:启动失败:-38

    简介 如何检查录音是否已在其他应用程序的后台运行 详细信息 如果录音已在本机应用程序的后台运行 录音机 现在我已将录音作为我的应用程序中的功能之一 问题 当我尝试同时在我的应用程序中录制时 出现错误 E MediaRecorder star
  • 使用张量流理解 LSTM 模型进行情感分析

    我正在尝试使用 Tensorflow 学习 LSTM 模型进行情感分析 我已经经历了LSTM模型 http colah github io posts 2015 08 Understanding LSTMs 以下代码 create sent
  • C# Winforms:以编程方式显示按钮悬停状态

    我在 winform 上显示数字键盘来输入代码 我正在显示带有按钮的数字键盘 用户将仅使用键盘数字键盘来输入代码 密码 但当然你可以使用鼠标 如果我们使用鼠标单击按钮 我们会得到蓝色效果来显示悬停和按下状态 我在想我是否可以以某种方式以编程
  • Tensorflow 2.2.0 错误:[预测必须 > 0] [条件 x >= y 不满足元素方向:] 使用双向 LSTM 层时

    在处理命名实体识别任务时 我收到以下错误消息 tensorflow python framework errors impl InvalidArgumentError assertion failed predictions must be
  • 如何在应用程序退出前执行代码 flutter

    我想检测用户何时退出我的应用程序并执行一些代码 但我不知道如何执行此操作 我尝试使用这个包 https pub dev packages flutter lifecycle state https pub dev packages flut
  • 我不理解非确定性图灵机的概念[关闭]

    Closed 这个问题是无关 help closed questions 目前不接受答案 我不明白这个概念非确定性图灵机 我想我理解这个词非确定性算法 非确定性算法是一种可以在不同的情况下表现出不同行为的算法 运行 而不是确定性算法 所以该
  • 更改 ViewController 时如何保持 UISwitch 状态?

    当我从一个视图控制器移动到另一个视图控制器时 第一个控制器上的开关会自行重置并且不保留其状态 在查看其他控制器后返回时如何使其保存状态 以及如何让它在关闭应用程序后保存其状态 我查看了各种 stackOverflow 问题和回复以及苹果文档
  • Python - 基于 LSTM 的 RNN 需要 3D 输入?

    我正在尝试构建一个基于 LSTM RNN 的深度学习网络 这是尝试过的 from keras models import Sequential from keras layers import Dense Dropout Activatio
  • Tensorflow:如何使用dynamic_rnn从LSTMCell获取中间细胞状态(c)?

    默认情况下 函数dynamic rnn仅输出隐藏状态 称为m 对于每个时间点可以通过如下方式获得 cell tf contrib rnn LSTMCell 100 rnn outputs tf nn dynamic rnn cell inp
  • 将状态与 IO 操作结合起来

    假设我有一个状态单子 例如 data Registers Reg data ST ST registers Registers memory Array Int Int newtype Op a Op runOp ST gt ST a in
  • 如何为 keras lstm 输入重塑数据?

    我是 Keras 新手 我发现很难理解 LSTM 层输入数据的形状 Keras 文档表示输入数据应该是形状为 nb samples timesteps input dim 的 3D 张量 我有808信号 每个信号有22个通道和2000个数据
  • 在 Keras 中使用 Subtract 层

    我正在 Keras 中实现所描述的 LSTM 架构here http nlp cs rpi edu paper multilingualmultitask pdf 我认为我已经非常接近了 尽管我在共享层和特定语言层的组合方面仍然存在问题 这
  • 为什么我的 keras LSTM 模型陷入无限循环?

    我正在尝试构建一个小型 LSTM 它可以通过在现有 Python 代码上进行训练来学习编写代码 即使是垃圾代码 我已将数百个文件中的数千行代码连接到一个文件中 每个文件以
  • LSTM 错误:AttributeError:“tuple”对象没有属性“dim”

    我有以下代码 import torch import torch nn as nn model nn Sequential nn LSTM 300 300 nn Linear 300 100 nn ReLU nn Linear 300 7
  • 尝试理解 Pytorch 的 LSTM 实现

    我有一个包含 1000 个示例的数据集 其中每个示例都有5特征 a b c d e 我想喂7LSTM 的示例 以便它预测第 8 天的特征 a 阅读 nn LSTM 的 Pytorchs 文档 我得出以下结论 input size 5 hid
  • 重置流的状态

    我有一个问题与 stackoverflow 上的这个问题有点相似std cin clear 无法将输入流恢复到良好状态 https stackoverflow com questions 4960399 stdcin clear fails
  • 提交后清除 React 中的表单

    我试图在使用 Axios 创建表单提交后清除表单数据 消息处理良好 响应记录到页面 但每个文本字段中的数据在提交后仍保留在页面上 我尝试添加一个resetForm函数 我将表单设置回原来的空白状态 但这不起作用 import React C
  • Tensorflow 的 LSTM 输入

    I m trying to create an LSTM network in Tensorflow and I m lost in terminology basics I have n time series examples so X
  • FocusState Textfield 在工具栏 ToolbarItem 中不起作用

    让我解释一下 我有一个带有 SearchBarView 的父视图 我正在传递这样的焦点状态绑定 SearchBarView searchText object searchQuery searching object searching f

随机推荐

  • 6个常用的React组件库

    Ant Design 项目链接 xff1a Ant Design 包大小 xff08 来自 BundlePhobia xff09 xff1a 缩小后 1 2mB xff0c 缩小 43 gzip 压缩后 349 2kB xff0c 通过摇树
  • 大数据培训课程数据清洗案例实操-简单解析版

    数据清洗 xff08 ETL xff09 在运行核心业务MapReduce程序之前 xff0c 往往要先对数据进行清洗 xff0c 清理掉不符合用户要求的数据 清理的过程往往只需要运行Mapper程序 xff0c 不需要运行Reduce程序
  • 宋红康2023版Java视频发布

    1500万 43 播放量见证经典 xff0c 尚硅谷宋红康老师的Java入门视频堪称神作 xff0c 如今经典再次超级进化 xff0c 新版Java视频教程震撼来袭 xff01 开发环境全新升级 xff1a JDK17 43 IDEA202
  • Java消息队列:消息在什么时候会变成Dead Letter?

    在较为重要的业务队列中 xff0c 确保未被正确消费的消息不被丢弃 xff0c 通过配置死信队列 xff0c 可以让未正确处理的消息暂存到另一个队列中 xff0c 待后续排查清楚问题后 xff0c 编写相应的处理代码来处理死信消息 一 什么
  • Vue2和Vue3数据双向绑定原理的区别及优缺点(下篇)

    上篇我们讲到了Vue2的数据双向绑定原理 xff0c 如果你没有阅读上篇 xff0c 建议先阅读一下上篇中的内容 Vue2和Vue3数据双向绑定原理的区别及优缺点 xff08 上篇 xff09 在上篇中我们抛出了一个问题 xff1a 是不是
  • FlinkTable时间属性

    像窗口 xff08 在 Table API 和 SQL xff09 这种基于时间的操作 xff0c 需要有时间信息 因此 xff0c Table API 中的表就需要提供逻辑时间属性来表示时间 xff0c 以及支持时间相关的操作 一 处理时
  • kafka学习(1)

    目录 kafka是什么 xff1f 为什么要用kafka kafka的特点 kafka结构 Kafka Producer的Ack机制 kafka是什么 xff1f 收集nginx日志 xff0c 将nginx日志的关键字段进行分析 xff0
  • spss 因子分析

    是通过研究变量间的相关系数矩阵 xff0c 把这些变量间错综复杂的关系归结成少数几个综合因子 xff0c 并据此对变量进行分类的一种统计方法 xff0c 归结出的因子个数少于原始变量的个数 xff0c 但是他们又包含原始变量的信息 xff0
  • Hive 报错 Invalid column reference 列名

    两张表 当我执行 select m movieid m moviename substr m moviename 5 4 as years avg r rate as avgScore FROM t movie as m join t ra
  • 20数学建模C-中小微企业的信贷决策

    前言 源码文末获取 小编在 9 月份参加了今年的数学建模 xff0c 成绩怎么样不知道 xff0c 能有个成功参与奖就不错了哈哈 最近整理了一下 xff0c 写下这篇文章分享小编的思路 能力知识水平有限 xff0c 欢迎各位大佬前来指教 o
  • playwright 爬虫使用

    官方文档 xff1a Getting started Playwright Python 参考链接 xff1a 强大易用 xff01 新一代爬虫利器 Playwright 的介绍 目录 安装 基本使用 代码生成 AJAX 动态加载数据获取
  • kmeans聚类选择最优K值python实现

    来源 xff1a https www omegaxyz com 2018 09 03 k means find k 下面利用python中sklearn模块进行数据聚类的K值选择 数据集自制数据集 xff0c 格式如下 xff1a 维度为3
  • mysql构造页损坏

    构造页损坏 及修复方式可参考 gg gMysql页面crash问题复现 amp 恢复方法 阿里云开发者社区 也可通过 dd 命令进行构造 dd xff0c 命令参考 xff1a Linux dd 命令 菜鸟教程
  • mysql审计日志过滤sql功能

    审计日志功能是一个插件 xff0c 需要先安装插件才可以使用 过滤 sql 语句 xff0c 可以通过插件内核参数 audit log include commands 与 audit log exclude commands 参数设置 x
  • setDaemon python守护进程,队列通信子线程

    使用setDaemon 和守护线程这方面知识有关 xff0c 比如在启动线程前设置thread setDaemon True xff0c 就是设置该线程为守护线程 xff0c 表示该线程是不重要的 进程退出时不需要等待这个线程执行完成 这样
  • 中文与 \u5927\u732a\u8e44\u5b50 这一类编码互转

    了解更多关注微信公众号 木下学Python 吧 a 61 39 大猪蹄子 39 a 61 a encode 39 unicode escape 39 print a 运行结果 xff1a b 39 u5927 u732a u8e44 u5b
  • python字典删除键值对

    https blog csdn net uuihoo article details 79496440
  • 计算机网络(4)传输层

    目录 小知识点 xff1a 三次握手 xff1a 状态 xff1a tcpdump xff1a 一 xff1a 命令介绍 xff1a 二 xff1a 命令选项 xff1a tcpdump的表达式 xff1a 使用python扫描LAN工具
  • MSE 治理中心重磅升级-流量治理、数据库治理、同 AZ 优先

    作者 xff1a 流士 本次 MSE 治理中心在限流降级 数据库治理及同 AZ 优先方面进行了重磅升级 xff0c 对微服务治理的弹性 依赖中间件的稳定性及流量调度的性能进行全面增强 xff0c 致力于打造云原生时代的微服务治理平台 前情回
  • TF多层 LSTM 以及 State 之间的融合

    第一是实现多层的LSTM的网络 第二是实现两个LSTM的state的concat操作 分析 state 的结构 对于第一个问题 之前一直没有注意过 看下面两个例子 在这里插入代码片 import tensorflow as tf num u