张量流:简单 LSTM 网络的共享变量错误

2024-04-28

我正在尝试构建一个最简单的 LSTM 网络。只是想让它预测序列中的下一个值np_input_data.

import tensorflow as tf
from tensorflow.python.ops import rnn_cell
import numpy as np

num_steps = 3
num_units = 1
np_input_data = [np.array([[1.],[2.]]), np.array([[2.],[3.]]), np.array([[3.],[4.]])]

batch_size = 2

graph = tf.Graph()

with graph.as_default():
    tf_inputs = [tf.placeholder(tf.float32, [batch_size, 1]) for _ in range(num_steps)]

    lstm = rnn_cell.BasicLSTMCell(num_units)
    initial_state = state = tf.zeros([batch_size, lstm.state_size])
    loss = 0

    for i in range(num_steps-1):
        output, state = lstm(tf_inputs[i], state)
        loss += tf.reduce_mean(tf.square(output - tf_inputs[i+1]))

with tf.Session(graph=graph) as session:
    tf.initialize_all_variables().run()

    feed_dict={tf_inputs[i]: np_input_data[i] for i in range(len(np_input_data))}

    loss = session.run(loss, feed_dict=feed_dict)

    print(loss)

解释器返回:

ValueError: Variable BasicLSTMCell/Linear/Matrix already exists, disallowed. Did you mean to set reuse=True in VarScope? Originally defined at:
    output, state = lstm(tf_inputs[i], state)

我做错了什么?


致电给lstm here:

for i in range(num_steps-1):
  output, state = lstm(tf_inputs[i], state)

除非您另有说明,否则每次迭代都会尝试创建具有相同名称的变量。您可以使用以下方法执行此操作tf.variable_scope

with tf.variable_scope("myrnn") as scope:
  for i in range(num_steps-1):
    if i > 0:
      scope.reuse_variables()
    output, state = lstm(tf_inputs[i], state)     

第一次迭代创建代表 LSTM 参数的变量,以及每次后续迭代(在调用reuse_variables)只会在范围内按名称查找它们。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

张量流:简单 LSTM 网络的共享变量错误 的相关文章

随机推荐

  • 如何向所有用户授予团队项目访问权限?

    在我们的组织中 我们有一些项目 根据政策 向所有有权访问 TFS 的开发人员 QA 和项目经理开放 在包含超过 150 个团队项目和 500 多个有效用户的团队项目集合中 我们如何轻松地将所有有效 TFS 用户添加到某个特定项目的 读者 组
  • 在文件之间共享变量(没有全局变量)

    据我了解 使用创建的变量let在 Javascript 中不能是全局的 我thought这意味着该变量仅存在于该特定文件中 然而 当我做一个简单 人为的例子时 A js let a 5 B js console log a 索引 html
  • 字符串到 CLLocation 纬度和经度

    我有两个表示纬度和经度的字符串 例如 56 6462520 我想将其分配给 CLLocation 对象以与我当前的位置进行比较 我尝试了以下代码 但只收到错误 CLLocation LocationAtual CLLocation allo
  • Objective-C 中“@public”是什么意思?

    读完一篇后关于 private 的问题 https stackoverflow com questions 844658 what does private mean in objective c我明白这是如何运作的 但是 由于所有变量都默
  • 操作系统如何知道缺失页面的磁盘地址?

    分页充当虚拟地址空间和物理地址空间之间的间接层 给定一个地址 操作系统 OS 内存管理单元 MMU 将其转换为主内存位置 我的问题是 主内存中不存在该页面的情况 操作系统如何知道在磁盘上哪里可以找到该页面 它在哪里存储1的信息 它不存储在页
  • 如何使用 ASP.NET 访问 Facebook 广告 API

    我希望访问使用适用于 NET 的 FaceBook 工具包的 FaceBook 广告 API 我在 codeplex com 中找到的 希望访问ads estimateTargetingStats尤其 Facebook 广告 API 详情
  • Html 5 音频标签自定义控件?

    我觉得我在这里服用了疯狂的药丸 因为我不知道如何使用自定义控件渲染 html 5 音频标签 到目前为止我有这个 html 它工作没有问题
  • 如何在 iPython 笔记本中保存单元格的输出

    我希望能够保存 iPython 笔记本的文本输出cell到磁盘上的文件中 我有 2 个额外的要求 要求 能够重新运行单元并用最新的内容覆盖我的输出 还显示笔记本内的输出 我已经弄清楚如何使用 capture将 iPython 笔记本的单元格
  • Swift Alamofire + Promise 捕获

    伙计们 除了catch之外 以下工作正常 xcode错误与expected member name following 这是使用 PromiseKit 进行承诺的正确方法吗 欢迎所有建议 谢谢 IBAction func loginButt
  • 使用 ProcessBuilder 运行 msys.bat

    我正在尝试使用 ProcessBuilder 在 java 中运行 msys bat 当我使用程序运行 bat 文件时 出现以下错误 找不到 rxvt exe 或 sh exe 二进制文件 正在中止 按任意键继续 这是代码 ProcessB
  • 当字符串和类都是引用类型时

    这是我上次面试的情况 问题 字符串存储在哪里 Answer 堆因为它是引用类型 问题 解释一下下面的代码 static void Main string args string one test string two one one one
  • 增加 C++ 程序 CPU 使用率

    我有一个用 C 编写的程序 每秒运行多个 for 循环 而不使用任何会使其因任何原因等待的东西 它始终使用 2 10 的 CPU 有没有什么方法可以强制它使用更多的CPU并进行更多的计算而不使程序变得更复杂 此外 我在 Windows 计算
  • 在当前元素的 onchange 上发送 $(this)

    我有这个html
  • 我可以在 iTunes Connect 中恢复到之前版本的应用程序吗?

    我在App Store中有应用程序 我提交了1 1版本 在Apple审核 批准和发布后 我注意到有一个明显的重大错误 所以我不得不从App Store暂停我的应用程序 我提交了新版本 1 2 您知道审核和发布需要 5 7 天 在新版本发布期
  • 如何使用 ggplot 绘制反向(互补)ecdf?

    我目前使用 stat ecdf 来绘制累积频率图 这是我使用的代码 cumu plot lt ggplot house total year aes download speed colour ISP stat ecdf size 1 但是
  • 获取 HTML 元素相对于窗口的边界框的正确方法是什么?

    我正在编写一个 Firefox 扩展 我试图将其限制为仅 XUL Javascript 无 XPCOM 当我得到一个mouseover对于 HTML 元素的事件 我需要获取其在 Windows 坐标系中的边界框 即内置 XUL 文档 bro
  • 你能在 MS Windows 上用 Python 将 stdin 作为文件打开吗?

    在 Linux 上 我使用 subbprocess Popen 来运行应用程序 该应用程序的命令行需要输入文件的路径 我了解到我可以将路径 dev stdin 传递到命令行 然后使用 Python 的 subproc stdin write
  • Google Spreadsheet Api 结构化查询语法的官方参考

    我正在寻找用于创建的查询语法的官方参考结构化查询用于请求 Google Spreadsheet API 中的行 如所讨论的here https developers google com google apps spreadsheets s
  • R 识别数据框列中的文本字符串

    我的数据框的一列包含单词和短语 我正在尝试为此列中具有特定文本字符串的字段创建一个虚拟变量 例如 kite cars 箱形风筝 模型车 我喜欢飞翔的风筝 世界汽车 myvector lt c kite cars box kites mode
  • 张量流:简单 LSTM 网络的共享变量错误

    我正在尝试构建一个最简单的 LSTM 网络 只是想让它预测序列中的下一个值np input data import tensorflow as tf from tensorflow python ops import rnn cell im