Tensorflow中GRU单元的解释?

2024-04-08

以下是 Tensorflow 的代码GRUCell单元显示了当先前的隐藏状态与序列中的当前输入一起提供时获得更新的隐藏状态的典型操作。

  def __call__(self, inputs, state, scope=None):
    """Gated recurrent unit (GRU) with nunits cells."""
    with vs.variable_scope(scope or type(self).__name__):  # "GRUCell"
      with vs.variable_scope("Gates"):  # Reset gate and update gate.
        # We start with bias of 1.0 to not reset and not update.
        r, u = array_ops.split(1, 2, _linear([inputs, state],
                                             2 * self._num_units, True, 1.0))
        r, u = sigmoid(r), sigmoid(u)
      with vs.variable_scope("Candidate"):
        c = self._activation(_linear([inputs, r * state],
                                     self._num_units, True))
      new_h = u * state + (1 - u) * c
return new_h, new_h

但我没有看到任何weights and biases这里。 例如我的理解是r and u需要将权重和偏差与当前输入和/或隐藏状态相乘以获​​得更新的隐藏状态。

我写了一个gru单元如下:

def gru_unit(previous_hidden_state, x):
    r  = tf.sigmoid(tf.matmul(x, Wr) + br)
    z  = tf.sigmoid(tf.matmul(x, Wz) + bz)
    h_ = tf.tanh(tf.matmul(x, Wx) + tf.matmul(previous_hidden_state, Wh) * r)
    current_hidden_state = tf.mul((1 - z), h_) + tf.mul(previous_hidden_state, z)
    return current_hidden_state

这里我明确地使用了权重Wx, Wr, Wz, Wh和偏见br, bh, bz等来获取更新的隐藏状态。这些权重和偏差是训练后学习/调整的。

如何利用 Tensorflow 的内置功能GRUCell达到与上面相同的结果?


它们在那里,您只是在代码中看不到它们,因为 _线性 函数添加了权重和偏差。

r, u = array_ops.split(1, 2, _linear([inputs, state],
                                             2 * self._num_units, True, 1.0))

...

def _linear(args, output_size, bias, bias_start=0.0, scope=None):
  """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.

  Args:
    args: a 2D Tensor or a list of 2D, batch x n, Tensors.
    output_size: int, second dimension of W[i].
    bias: boolean, whether to add a bias term or not.
    bias_start: starting value to initialize the bias; 0 by default.
    scope: VariableScope for the created subgraph; defaults to "Linear".

  Returns:
    A 2D Tensor with shape [batch x output_size] equal to
    sum_i(args[i] * W[i]), where W[i]s are newly created matrices.

  Raises:
    ValueError: if some of the arguments has unspecified or wrong shape.
  """
  if args is None or (nest.is_sequence(args) and not args):
    raise ValueError("`args` must be specified")
  if not nest.is_sequence(args):
    args = [args]

  # Calculate the total size of arguments on dimension 1.
  total_arg_size = 0
  shapes = [a.get_shape().as_list() for a in args]
  for shape in shapes:
    if len(shape) != 2:
      raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes))
    if not shape[1]:
      raise ValueError("Linear expects shape[1] of arguments: %s" % str(shapes))
    else:
      total_arg_size += shape[1]

  # Now the computation.
  with vs.variable_scope(scope or "Linear"):
    matrix = vs.get_variable("Matrix", [total_arg_size, output_size])
    if len(args) == 1:
      res = math_ops.matmul(args[0], matrix)
    else:
      res = math_ops.matmul(array_ops.concat(1, args), matrix)
    if not bias:
      return res
    bias_term = vs.get_variable(
        "Bias", [output_size],
        initializer=init_ops.constant_initializer(bias_start))
  return res + bias_term
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow中GRU单元的解释? 的相关文章

  • 在 Chaquopy 中转换数组和张量

    我该怎么做呢 我看到你的帖子说你可以将 java 对象传递给 Python 方法 但这不适用于 numpy 数组和 TensorFlow 张量 以下以及其各种变体是我尝试过的 但没有成功 double anchors new double
  • Tensorflow 的 LSTM 输入

    I m trying to create an LSTM network in Tensorflow and I m lost in terminology basics I have n time series examples so X
  • 在tensorflow .ckpt文件中使用预训练模型

    我有一个 ckpt 文件 我只想得到 cnn 的权重 我已经从 ckpt 检查点文件中进行了训练 inception resnet v2 2016 08 30 import tensorflow as tf saver tf train S
  • 如何在 Caffe 的网络中出现多次损失?

    如果我在网络中定义多个损失层 从这些末端到网络的开头是否会发生多个反向传播 我的意思是 他们真的是这样工作的吗 假设我有这样的事情 Layer1 Layer2 Layer n Layer cls1 bottom layer n top cl
  • 如何创建 Keras 层来执行 4D 卷积 (Conv4D)?

    看起来tf nn convolution应该能够进行 4D 卷积 但我无法成功创建 Keras 层来使用此函数 我尝试过使用 KerasLambda层来包裹tf nn convolution功能 但也许其他人有更好的主意 我想利用数据的高维
  • 如何将one-hot向量转换为多标签?

    我有一项多分类任务 并且我得到了像这样的单热类型预测 0 1 1 0 1 0 1 0 1 我希望将这个单热向量转换为标签 例如 1 2 1 0 2 我已经尝试过 tf argmax 但它不起作用 那么我该如何处理呢 使用列表理解 oheLi
  • 这可能是因为 cuDNN 初始化失败,因此请尝试查看上面是否打印了警告日志消息。 [操作:Conv2D]

    我在 anaconda 中安装了 TensorFlow GPU 2 0 当我安装它并导入包 然后运行我的 CNN 模型时 它工作正常 但当我尝试运行训练模型时 出现错误 这是我的错误报告 Epoch 1 50 UnknownError Tr
  • Keras 中的损失函数和度量有什么区别? [复制]

    这个问题在这里已经有答案了 我不清楚 Keras 中损失函数和指标之间的区别 该文档对我没有帮助 损失函数用于优化您的模型 这是优化器将最小化的函数 指标用于判断模型的性能 这仅供您查看 与优化过程无关
  • LSTM 批次与时间步

    我按照 TensorFlow RNN 教程创建了 LSTM 模型 然而 在这个过程中 我对 批次 和 时间步长 之间的差异 如果有的话 感到困惑 并且我希望得到帮助来澄清这个问题 教程代码 见下文 本质上是根据指定数量的步骤创建 批次 wi
  • 如何在 Keras 中将多个数据集与一个模型一起使用?

    我正在尝试使用 LSTM 网络通过 Keras 和 Tensorflow 进行外汇预测 我当然希望它能够在很多天的交易中进行训练 但要做到这一点 我必须给它提供具有大跳跃和无运动阶段的连续数据 当市场收盘时 这并不理想 因为它变得由于这些跳
  • 用于分布式计算的 Tensorflow 设置

    任何人都可以提供有关如何设置张量流以在网络上的许多CPU上工作的指导吗 到目前为止 我发现的所有示例最多只使用一个本地盒子和多个 GPU 我发现我可以在 session opts 中传递目标列表 但我不确定如何在每个盒子上设置张量流来侦听网
  • 错误:tensorflow:无法匹配检查点的文件

    我正在训练一个张量流模型 在每个时期之后我都会保存模型状态并腌制一些数组 到目前为止 我的模型执行了 2 个纪元 并且保存状态的文件夹包含以下文件 checkpoint model e knihy preprocessed txt e0 c
  • Keras ImageDataGenerator 相当于 csv 文件

    我在文件夹中排序了一堆数据 如下图所示 我需要构建一个 DataIterator 以便将数据放入神经网络模型中 当数据是图像时 我找到了很多例子来解决这个问题 使用 Keras 类图像数据生成器及其方法流自目录 但当数据是 csv 结构时则
  • Keras ZeroDivisionError:整数除法或以零为模

    我正在尝试使用 Keras 和 Tensorflow 实现卷积神经网络 我有以下代码 from keras models import Sequential from keras layers import Conv2D MaxPoolin
  • 在 Tensorflow 对象检测 API 中绘制验证损失

    我正在使用 Tensorflow 对象检测 API 来检测和定位图像中的一类对象 为了这些目的 我使用预先训练的faster rcnn resnet50 coco 2018 01 28 model 我想在训练模型后检测拟合不足 过度拟合 我
  • 在监督分类中,使用partial_fit() 的MLP 比使用fit() 的表现更差

    我正在使用的学习数据集是灰度图像flatten让每个像素代表一个单独的样本 第二张图像在训练后将被逐像素分类Multilayer perceptron MLP 前一个分类器 我遇到的问题是MLP当它一次接收到所有训练数据集时表现更好 fit
  • 在不同的 GPU 上同时训练多个 keras/tensorflow 模型

    我想在 Jupyter Notebook 中同时在多个 GPU 上训练多个模型 我正在使用 4GPU 的节点上工作 我想将一个 GPU 分配给一个模型并同时训练 4 个不同的模型 现在 我通过 例如 为一台笔记本选择 GPU import
  • scikit-learn 和tensorflow 有什么区别?可以一起使用它们吗?

    对于这个问题我无法得到满意的答案 据我了解 TensorFlow是一个数值计算库 经常用于深度学习应用 而Scikit learn是一个通用机器学习框架 但它们之间的确切区别是什么 TensorFlow 的目的和功能是什么 我可以一起使用它
  • 可视化 TFLite 图并获取特定节点的中间值?

    我想知道是否有办法知道 tflite 中特定节点的输入和输出列表 我知道我可以获得输入 输出详细信息 但这不允许我重建发生在Interpreter 所以我要做的是 interpreter tf lite Interpreter model
  • 对于只有 10000 个单词的字典来说,真正需要什么嵌入层 output_dim?

    我正在训练一个 RNN 其单词特征集非常少 大约 10 000 个 我计划在添加 RNN 之前从嵌入层开始 但我不清楚真正需要什么维度 我知道我可以尝试不同的值 32 64 等 但我宁愿先有一些直觉 例如 如果我使用 32 维嵌入向量 则每

随机推荐

  • 如何在“border-*”属性中使用百分比?

    我有使用 Twitter Bootstrap 3 的代码 nav with right arrow 我使用它创建的border 特性 但是如果我在中使用很长的文本right arrow 它不会扩展 如果我使用百分比 代码将无法工作 示例Js
  • 复制virtualenv文件夹后如何在Cygwin中激活virtualenv

    完整的初学者在这里 尝试构建一个 Flask Web 应用程序 使用 Windows 8 在 Cygwin 中激活我的 python virtualenv 时遇到一些问题 到目前为止我一直在使用 git shell 没有任何问题 我将文件夹
  • React.js:将默认值设置为 prop

    我制作了这个组件来创建一个简单的按钮 class AppButton extends Component setOnClick if this props onClick typeof this props onClick function
  • 在 ASP.NET MVC 3 应用程序中扩展 Windows 身份验证

    经过大量谷歌搜索并阅读了有关如何在 ASP NET 应用程序中管理混合模式身份验证的几种解决方案后 我仍然没有适合我的问题的解决方案 我必须为一堆不同的用户组实现一个 Intranet 应用程序 到目前为止 我一直使用 Windows 身份
  • 无法在 ubuntu 19.04 上安装 libzmq3-dev

    我正在尝试安装libzmq3 dev on 乌班图19 04 使用命令 sudo apt install build essential libsocketcan dev libzmq3 dev 我收到消息 gt Some packages
  • Pentaho Spoon 工具转换顺序

    我正在尝试设计一个 ETL 结构 但我陷入了以下步骤 正如你所看到的 我有 3 个步骤 每个步骤都有一个FK上一步的值 例如TABLE3有一个列外键约束这表明PK值在TABLE2 and TABLE2与 具有相同的关系TABLE1 问题是
  • 如何在我的 Maven 项目中正确包含“org.apache.catalina.filters.SetCharacterEncodingFilter”过滤器?

    我使用 Maven 3 3 和 JBoss 7 1 3 Final Java 6 我想在我的 Web 应用程序中包含一个过滤器 以便所有传入请求数据都将编码为 UTF 8 所以我将其添加到我的 web xml 文件中
  • Powershell CheckedListBox 检查是否在字符串/数组中

    我已经开始学习 Powershell 但在花了几个小时解决一个问题后陷入困境 我可以找到除 Powershell 之外的多种语言的解决方案 我需要对 CheckedListBox 中的每个项目进行检查 该项目与名为的分号分隔字符串中的任何值
  • WPF 中 WinForms TextBox.Validating 事件的等效项

    在 WinForms 中 我可以处理 Validated 事件 以便在用户更改 TextBox 中的文本后执行某些操作 与 TextChanged 不同 Validated 不会在每次字符更改时触发 它仅在用户完成后触发 WPF 中是否有任
  • 我到底必须在 viewDidUnload 中做什么?

    我倾向于在 dealloc 中释放我的东西 现在 iPhone OS 3 0 引入了这个有趣的 viewDidUnload 方法 他们说 释放所有保留的子视图 主要视图 例如自我我的出口 零 因此 当视图控制器的视图从内存中启动时 view
  • Pandas - 按一列分组,按另一列排序,从第三列获取值

    我想采用 pandas 数据框 按一列对其进行分组 按另一列对其进行排序 并从第三列中获取第一个元素并填充原始数据框 这是我原来的 df 我将按 col 1 分组 按 col 2 升序 排序 并从 col 3 中取出第一个元素并用结果填充
  • 对角线穿过视图

    根据某些条件 我必须对角剪切列表单元格 为此 我使用以下代码制作了对角线可绘制图像 对角线 xml
  • 沿多边形边界随机采样点

    I am trying to randomly sample points on a polygon boundary made of arbitrary number of points The polygon consist of a
  • C++中的默认参数

    考虑以下 int foo int x int z 0 int foo int x int y int z 0 如果我像这样调用这个函数 foo 1 2 编译器如何知道使用哪一个 它不会 因此这个例子不会编译干净 它会给你一个编译错误 它会给
  • Cardview 涟漪效应不起作用

    最小 SDK 为 21 当我单击回收器适配器中的卡片视图时 不会发生连锁反应 只会转到下一个屏幕 recyclerview 位于片段内
  • JDBC 无法加载数据源的工厂类

    我已经遇到这个问题好几天了 但没有设法解决它 我使用的是 tomcat 7 0 我完全无法连接 mysql 数据库 我正在编写的应用程序是一个使用eclipse IDE的jsp动态网站 TomCat 7 启动时出现此错误 WARNING F
  • 为什么 Z3 在这个简单的输入上返回“未知”?

    这是输入 set option auto config false set option mbqi false declare sort T6 declare sort T7 declare fun set23 T7 T7 Bool ass
  • 在 Aptana Studio 3 中禁用 CSS 验证

    有人知道如何使用 Aptana Studio 3 禁用 CSS 验证吗 在版本 3 0 4 中 即使完全完成后 警告仍然存在禁用 W3C CSS 验证器 https stackoverflow com questions 6652793 h
  • 在最近的 JVM 中,不可见引用仍然是一个问题吗?

    我正在读书Java 平台性能 http java sun com docs books performance 1st edition html JPAppGC fm html 遗憾的是 自从我最初提出这个问题以来 该链接似乎已经从互联网上
  • Tensorflow中GRU单元的解释?

    以下是 Tensorflow 的代码GRUCell单元显示了当先前的隐藏状态与序列中的当前输入一起提供时获得更新的隐藏状态的典型操作 def call self inputs state scope None Gated recurrent