尝试理解 Pytorch 的 LSTM 实现

2024-04-19

我有一个包含 1000 个示例的数据集,其中每个示例都有5特征(a、b、c、d、e)。我想喂7LSTM 的示例,以便它预测第 8 天的特征 (a)。

阅读 nn.LSTM() 的 Pytorchs 文档,我得出以下结论:

input_size = 5
hidden_size = 10
num_layers = 1
output_size = 1

lstm = nn.LSTM(input_size, hidden_size, num_layers)
fc = nn.Linear(hidden_size, output_size)

out, hidden = lstm(X)  # Where X's shape is ([7,1,5])
output = fc(out[-1])

output  # output's shape is ([7,1])

根据文档:

nn.LSTM 的输入是“输入形状(seq_len、批次、input_size


当我将您的代码扩展为完整的示例时——我还添加了一些可能有帮助的注释——我得到以下信息:

import torch
import torch.nn as nn

input_size = 5
hidden_size = 10
num_layers = 1
output_size = 1

lstm = nn.LSTM(input_size, hidden_size, num_layers)
fc = nn.Linear(hidden_size, output_size)

X = [
    [[1,2,3,4,5]],
    [[1,2,3,4,5]],
    [[1,2,3,4,5]],
    [[1,2,3,4,5]],
    [[1,2,3,4,5]],
    [[1,2,3,4,5]],
    [[1,2,3,4,5]],
]

X = torch.tensor(X, dtype=torch.float32)

print(X.shape)         # (seq_len, batch_size, input_size) = (7, 1, 5)
out, hidden = lstm(X)  # Where X's shape is ([7,1,5])
print(out.shape)       # (seq_len, batch_size, hidden_size) = (7, 1, 10)
out = out[-1]          # Get output of last step
print(out.shape)       # (batch, hidden_size) = (1, 10)
out = fc(out)          # Push through linear layer
print(out.shape)       # (batch_size, output_size) = (1, 1)

这对我来说很有意义,考虑到你batch_size = 1 and output_size = 1(我假设,你正在做回归)。我不知道你在哪里output.shape = (7, 1)来自。

您确定您的X有正确的尺寸吗?你创建了吗nn.LSTM也许与batch_first=True?有很多小东西可以潜入。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

尝试理解 Pytorch 的 LSTM 实现 的相关文章

随机推荐

  • 带有 vararg observables 的 RxJava zip

    当我们确切地知道有多少个具有确切类型的可观察量并且我们想要压缩时 我们会这样做 Observable
  • JetBrains IDE 启动时出错:应用程序无法正确启动 (0xc000007b)

    我遇到了这个错误 但在重新安装 IDE 两次后几乎找不到解决方案 甚至我安装了 多合一运行时 但这也无济于事 因为我认为问题最初是在我更改了 Windows Defender 设置中的一些设置后开始的然后尝试重置它们 但肯定其他人报告了这个
  • lambda:通过引用捕获 const 引用是否应该产生未定义的行为?

    我刚刚在代码中发现了一个令人讨厌的错误 因为我通过引用捕获了对字符串的 const 引用 当 lambda 运行时 原始字符串对象已经消失了 引用的值是空的 而目的是它包含原始字符串的值 因此出现了错误 让我困惑的是 这并没有在运行时引发崩
  • BigInteger 数字的实现和性能

    我用 C 编写了一个 BigInteger 类 它应该能够对任何大小的所有数字进行运算 目前 我正在尝试通过比较现有算法并测试它们最适合哪些位数来实现非常快速的乘法方法 但我遇到了非常意外的结果 我尝试进行 20 次 500 位数字的乘法
  • Inflector.Net 的替代品

    我想在我的项目中使用 inflector net 刚刚谷歌了一下 好像已经消失了 http andrewpeters net inflectornet http andrewpeters net inflectornet 还有其他选择吗 编
  • Chrome 开发工具中的 __puppeteer_evaluation_script__ 为空

    Puppeteer 版本 9 0 0 将调试器放入 JavaScript 代码并启动 puppeteer 时 chrome 开发工具中的源代码为空 使用 Node 运行脚本 scripts test echo Error no test s
  • Docker 和 Python virtualenv 有什么区别?

    根据我对Docker的理解 它是一个用于虚拟环境的工具 用他们的行话来说 这称为 容器化 这或多或少就是 Python 的 virtualenv 所做的事情 但是 您可以使用 virtualenvin码头工人 那么 它是虚拟环境中的虚拟环境
  • 如何监控 Tomcat 服务器上的多个 Web 应用程序(使用 JMX)?

    有没有办法监控单个 Web 应用程序的 CPU 和内存消耗Tomcat server I have Tomcat打开其 JVM 下的所有 Web 应用程序 因此我只能看到一个 JVM 无法单独监控每个 Web 应用程序 Web 应用程序是密
  • jquery中删除多个元素

    在我当前的代码中我有这样的 foo remove bar remove 有没有办法通过使用删除多个元素remove once 它不限于 remove 但只需用逗号分隔选择器 foo bar remove 多重选择器 选择器1 选择器2 选择
  • 如何调整 UIImageView 的大小以适应底层图像而不移动它?

    我有一个 UIImageView 其框架在加载图像之前设置 对于图像来说总是太大 因此 例如 当我尝试圆角时 什么也没有发生 如何调整框架大小 使其与底层图像的大小相同 同时确保 UIImageView 的中心点不会改变 如果更改 UIVi
  • CPython的静态对象地址和碎片

    I read 对于Python来说 if x 是存储x的内存地址 这是给定的id对象的属性永远不会改变 这意味着对象在其生命周期中始终存储在给定的内存地址中 这就引出了一个问题 虚拟 内存碎片怎么样 说一个物体A位于地址 1 有id1 占用
  • IIS7的工作进程是什么?

    我正在尝试在 Visual Studio 2008 中执行 附加到进程 进行调试 但我无法弄清楚要附加到哪个进程 帮助 事实上它仍然是 w3wp exe 您需要检查 显示所有会话中的进程 让它显示的选项 这也让我困惑了一段时间
  • 如果不调用notify(),等待线程会发生什么?

    如果不调用notify 等待线程会发生什么 这是虚假唤醒吗 If a waiting Thread is not notified通过致电notify or notifyAll 在所述线程正在等待的对象上 则可能会发生以下任一情况 the
  • Chrome 调试协议:HeapProfiler.getHeapSnapshot 忽略回调

    我正在开发一个测试套件 作为 Chrome 扩展实现 该套件使用 Chrome Chromium 的远程调试协议以编程方式获取和分析堆快照 因为Profiler 似乎不是公共协议的一部分 我正在使用这一页 http trac webkit
  • HTML 多选框

    我只是想知道下面的表格的名称是什么 我从早上就在谷歌上搜索 HTML 表单列表 但我在任何地方都找不到这种表单 谁能告诉我这个表单的确切名称以及它是否可以在 HTML 表单中使用 我只想在我的网站中添加这种形式 它适用于 HTML 还是我应
  • 将变量传递给 Google Cloud Functions

    我刚刚在 Beta Python 3 7 运行时中使用 HTTP 触发器编写了 Google Cloud Function 现在我试图弄清楚如何在调用函数时将字符串变量传递给函数 我已阅读文档 但没有找到任何相关内容 我的触发器类似于 ht
  • 如何在光线平行且不使用光线模式的情况下运行函数?

    After sudo pip3 install ray 我创建了一个函数foo 在射线装饰器中定义 import ray ray init ray remote def foo x print x 我希望能够使用foo并行和常规模式 忽略装
  • ViewModel 和 Service 类的实例化

    我试图理解 ViewModel 和 Service 类的实例化 并将其写下来供其他人使用 请在需要的地方更正 添加 ViewModel 和服务的实例化并不是以最常见的方式完成的 这是使用反射完成的 在 TipCalc 中 您有 public
  • 在特定日期触发 UILocalNotification

    我想开火UILocalNotification在特定日期 如果我使用这段代码 NSCalendar gregorian NSCalendar alloc initWithCalendarIdentifier NSGregorianCalen
  • 尝试理解 Pytorch 的 LSTM 实现

    我有一个包含 1000 个示例的数据集 其中每个示例都有5特征 a b c d e 我想喂7LSTM 的示例 以便它预测第 8 天的特征 a 阅读 nn LSTM 的 Pytorchs 文档 我得出以下结论 input size 5 hid