张量流/tflearn 输入形状

2024-01-01

我正在尝试创建一个 lstm-rnn 来生成音乐序列。训练数据是大小为 4 的向量序列,表示一些要训练的歌曲中每个音符的各种特征(包括 MIDI 音符)。

从我的阅读来看,我想要做的是对于每个输入样本,输出样本是下一个大小为 4 的向量(即,它应该尝试在给定当前音符的情况下预测下一个音符,并且由于LSTM 融合了之前样本的知识)。

我正在使用 tflearn,因为我对 RNN 还很陌生。我有以下代码

net = tflearn.input_data(shape=[None, seqLength, 4])
net = tflearn.lstm(net, 128, return_seq=True)
net = tflearn.dropout(net, 0.5)
net = tflearn.lstm(net, 128)
net = tflearn.dropout(net, 0.5)
net = tflearn.fully_connected(net, 4, activation='softmax')
net = tflearn.regression(net, optimizer='adam',
                     loss='mean_square')

# Training
model = tflearn.DNN(net, tensorboard_verbose=3)
model.fit(trainX, trainY, show_metric=True, batch_size=128)

在这段代码之前,我已将 trainX 和 trainY 拆分为长度为 20 的序列(任意,但我在某处读到,对这样的序列进行训练是实现此目的的好方法)。

这似乎很好,但我收到错误 ValueError: Cannot feed value of shape (128, 16, 4) for Tensor u'TargetsData/Y:0', which has shape '(?, 4)'

SO:到目前为止,我的假设是输入形状 [None, seqLength, 4] 对 TF [batchLength(由 tflearn 顺序输入)、序列长度、样本特征长度] 表示。我不明白的是为什么它说输出形状错误?我对数据序列分割的假设是否错误?当我尝试输入所有数据而不拆分为序列时,输入形状为 [None, 4],TF 告诉我 LSTM 层需要至少具有 3 个维度的输入形状。

我无法弄清楚输入和输出的形状应该是什么。感觉这应该是一件简单的事情——我有一组向量输入序列,我希望网络尝试预测序列中的下一个。网上很少有不具备相当高级知识水平的内容,所以我遇到了困难。非常感谢任何人都能提供的任何见解!


我解决了这个问题,所以我在这里为遇到同样问题的人写下答案。这是基于对这些网络如何工作的误解,但这是我读过的大多数教程中假定的知识,因此其他初学者可能不清楚。

LSTM 网络对于这些情况非常有用,因为它们可以考虑输入历史记录。向 LSTM 提供历史记录的方式是通过排序,但每个序列仍然会导致单个输出数据点。因此输入必须是 3D 形状,而输出只是 2D 形状。

给定整个序列和所需的历史长度,我将输入拆分为历史长度序列和单个输出向量。这解决了我的形状问题。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

张量流/tflearn 输入形状 的相关文章

随机推荐

  • PayPal 400 错误请求,更具体吗?

    有没有办法获得比 400 bad request 更具体的 PayPal 错误 我看到有人在做这样的事情 if ex InnerException is ConnectionException Response Write Connecti
  • 从 C++ 代码运行 python 脚本并在 C++ 中使用 python 输出

    我有一个 C 程序 它已经开发得相当成熟 可以完成自己的工作 现在我们想向它添加一个附加功能 我们认为在 Python 中实现所述功能 然后使用来自 C 的所需输入调用该 python当需要时这将是最好的方法 因为它将它们分开并允许我们也从
  • 将 pandas 日期列转换为经过的秒数

    我有一个多列的 pandas 数据框 其中有一列 datetime64 ns 数据 时间采用 HH MM SS 格式 如何将这一列日期转换为一列经过的秒数 就像如果时间以秒为单位表示 10 00 00 则为 36000 秒应该采用 floa
  • 为什么使用 PatternSynonyms 会触发非详尽匹配警告?

    我正在跟进这个答案 https stackoverflow com a 31106916 12162258学习如何进行模式匹配Sequence https hackage haskell org package containers 0 6
  • 通过 LAN 网络从另一台计算机访问 localhost (xampp) - 如何?

    我刚刚在家里建立了 wi fi 网络 我的所有文件都在我的台式计算机 192 168 1 56 上 并且想从另一台计算机 192 168 1 2 访问那里的本地主机 在我的桌面上 我可以通过正常的 http localhost 访问 loc
  • 在上面的文件夹中包含 .php 文件

    我收到此错误 Warning include config php function include failed to open stream No such file or directory in home soizastu publ
  • 计算目录中的文件数

    我想计算目录中的文件数 我在 QDir 类中使用了 count 方法 但它总是返回文件数加二 它为什么要做这个工作 谢谢 你应该使用flags QDir Filters with QDir NoDotAndDotDot
  • 无法获取访问令牌:通过 OAuth2Bearer 访问 S/4HANA 时未找到有效的 JWT 承载者

    我通过以下方式生成了项目 mvn archetype generate DarchetypeGroupId com sap cloud s4hana archetypes DarchetypeArtifactId scp cf spring
  • 读取 csv 文件 - 一个列表中的每个字符

    我对 python 相当陌生 确实需要一些帮助 我现在还没有找到任何对我有帮助的东西 我想将 csv 文件读取到list 但不幸的是我的输出不符合预期 而不是有一个像这样的列表 Weiz 61744 Deutschlandsberg 564
  • 无法在 PHP 中加载 LDAP 函数

    当尝试使用ldap connect http php net ldap connect 我收到此错误 致命错误 调用未定义的函数 ldap connect 我重新编译了 php 启用了 LDAP apache 模块 并且也编辑了 php i
  • J Oliver EventStore V2.0 问题

    我正在着手使用 CQRS 实施一个项目 并打算使用 J Oliver EventStore V2 0 作为我的事件持久化引擎 1 在文档中 ExampleUsage cs在 BuildSerializer 中使用了3个序列化器 我想这只是为
  • Java程序中“无法解析驱动程序”

    我是编程世界的新手 所以我不知道如何解决这个问题 Test public void LoginEmail driver findElement By id email button sendKeys email protected cdn
  • 安全 Web 服务 (NTLM) - Jmeter

    我正在尝试使用 Jmeter 测试 Web 服务 Web 服务受 NTLM 身份验证 Windows 保护 我可以使用加载 WSDLWebService SOAP Request采样器 目前 仅当我将代理服务器与本文中提到的 BurpSui
  • 定义宏中的括号

    是什么时候必要的将定义宏的整个 右 表达式放在括号中 如果我做类似的事情 define SUM x y x y 我必须将正确的表达式放入括号中 因为 在 C 中的优先级较低 如果我在以下上下文中使用它 它将不起作用SUM x y 5U 如果
  • 实现SelectableDataModel

    XHTML 方面
  • 在python中将rgb转换为lab的快速方法

    有没有在Python3中使用D50 sRGB将RGB转换为LAB的快速方法 Python 色彩数学 https github com gtaylor python colormath太慢了 skimage http scikit image
  • 使用 Directory.Build.Prop for .NET Framework 添加包

    我有一个 Visual Studio 解决方案 我尝试在所有项目中使用构建 prop 文件添加代码分析器 我的项目依赖于 NET Core 以及框架 我有以下 Directory Build Prop 文件
  • 为什么我收到有关 Java 实用程序类的警告

    我正在学习 Java 和 OOPS 在 Eclipse 中编写基本的 Hello World 时 我看到一个黄色三角形告诉我 实用程序类不应具有公共或默认构造函数 我无法理解为什么会发生这种情况 这意味着什么 我做错了什么 class He
  • Rails 使用正在运行的构建器编写 xml

    我想在我的网站中使用 hipay 所以我需要在操作中生成一个 xml 然后通过帖子发送到 hipay 网站 我的问题是 我如何动态创建 xml 然后在同一操作中通过邮寄发送此 xml 我的控制器中的示例 def action generat
  • 张量流/tflearn 输入形状

    我正在尝试创建一个 lstm rnn 来生成音乐序列 训练数据是大小为 4 的向量序列 表示一些要训练的歌曲中每个音符的各种特征 包括 MIDI 音符 从我的阅读来看 我想要做的是对于每个输入样本 输出样本是下一个大小为 4 的向量 即 它