PyTorch:DecoderRNN:运行时错误:输入必须有 3 个维度,得到 2 个维度

2023-12-01

我正在使用 PyTorch 构建 DecoderRNN (这是一个图像标题解码器):

class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size):

        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.gru = nn.GRU(embed_size, hidden_size, hidden_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, features, captions):

        print (features.shape)
        print (captions.shape)
        output, hidden = self.gru(features, captions)
        output = self.softmax(self.out(output[0]))
        return output, hidden 

数据具有以下形状:

torch.Size([10, 200])  <- features.shape (10 for batch size)
torch.Size([10, 12])   <- captions.shape (10 for batch size)

然后我收到以下错误。我在这里错过了什么想法吗?谢谢!

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-2-76e05ba08b1d> in <module>()
     44         # Pass the inputs through the CNN-RNN model.
     45         features = encoder(images)
---> 46         outputs = decoder(features, captions)
     47 
     48         # Calculate the batch loss.

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    323         for hook in self._forward_pre_hooks.values():
    324             hook(self, input)
--> 325         result = self.forward(*input, **kwargs)
    326         for hook in self._forward_hooks.values():
    327             hook_result = hook(self, input, result)

/home/workspace/model.py in forward(self, features, captions)
     37         print (captions.shape)
     38         # features = features.unsqueeze(1)
---> 39         output, hidden = self.gru(features, captions)
     40         output = self.softmax(self.out(output[0]))
     41         return output, hidden

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    323         for hook in self._forward_pre_hooks.values():
    324             hook(self, input)
--> 325         result = self.forward(*input, **kwargs)
    326         for hook in self._forward_hooks.values():
    327             hook_result = hook(self, input, result)

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/rnn.py in forward(self, input, hx)
    167             flat_weight=flat_weight
    168         )
--> 169         output, hidden = func(input, self.all_weights, hx)
    170         if is_packed:
    171             output = PackedSequence(output, batch_sizes)

/opt/conda/lib/python3.6/site-packages/torch/nn/_functions/rnn.py in forward(input, *fargs, **fkwargs)
    383             return hack_onnx_rnn((input,) + fargs, output, args, kwargs)
    384         else:
--> 385             return func(input, *fargs, **fkwargs)
    386 
    387     return forward

/opt/conda/lib/python3.6/site-packages/torch/autograd/function.py in _do_forward(self, *input)
    326         self._nested_input = input
    327         flat_input = tuple(_iter_variables(input))
--> 328         flat_output = super(NestedIOFunction, self)._do_forward(*flat_input)
    329         nested_output = self._nested_output
    330         nested_variables = _unflatten(flat_output, self._nested_output)

/opt/conda/lib/python3.6/site-packages/torch/autograd/function.py in forward(self, *args)
    348     def forward(self, *args):
    349         nested_tensors = _map_variable_tensor(self._nested_input)
--> 350         result = self.forward_extended(*nested_tensors)
    351         del self._nested_input
    352         self._nested_output = result

/opt/conda/lib/python3.6/site-packages/torch/nn/_functions/rnn.py in forward_extended(self, input, weight, hx)
    292             hy = tuple(h.new() for h in hx)
    293 
--> 294         cudnn.rnn.forward(self, input, hx, weight, output, hy)
    295 
    296         self.save_for_backward(input, hx, weight, output)

/opt/conda/lib/python3.6/site-packages/torch/backends/cudnn/rnn.py in forward(fn, input, hx, weight, output, hy)
    206         if (not is_input_packed and input.dim() != 3) or (is_input_packed and input.dim() != 2):
    207             raise RuntimeError(
--> 208                 'input must have 3 dimensions, got {}'.format(input.dim()))
    209         if fn.input_size != input.size(-1):
    210             raise RuntimeError('input.size(-1) must be equal to input_size. Expected {}, got {}'.format(

RuntimeError: input must have 3 dimensions, got 2

您的 GRU 输入需要是 3 维的:

input形状(seq_len,batch,input_size):包含输入序列特征的张量。

此外,您需要提供隐藏状态(在本例中为最后一个编码器隐藏状态)作为第二个参数:

self.gru(input, h_0)

Where input是你的实际输入h_0隐藏状态也需要是 3 维的:

h_0形状(num_layers * num_directions、batch、hidden_​​size):张量 包含批次中每个元素的初始隐藏状态。 如果未提供,则默认为零。

https://pytorch.org/docs/master/nn.html#torch.nn.GRU

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch:DecoderRNN:运行时错误:输入必须有 3 个维度,得到 2 个维度 的相关文章

随机推荐

  • MySQL 对字符串第一部分的索引

    我正在 MySQL 中查询一个非常大的表 超过 3M 条记录 其中包含category id subcategory id 和邮政编码 数据库中的 zip 可能是也可能不是 10 个字符 目的是获取指定邮政编码的特定半径内的所有目录 子目录
  • Spring 注解 - 注入对象映射

    使用 XML 注释 我使用以下配置注入地图
  • PHP debug_backtrace 在生产代码中获取有关调用方法的信息?

    是否有令人信服的理由不使用debug backtrace仅仅是为了确定调用方法的类 名称和参数列表吗 不用于调试目的 它的函数名称中有 debug 一词 这让我觉得以这种方式使用它有点肮脏 但它符合我需要做的事情 一个可以从许多地方调用的单
  • 嵌入式linux ARM启动地址

    我按照一些文档通过 sdcard 在 ARM 板 例如 Freescale Vybrid tower 上启动嵌入式 Linux 在文档中 有构建 uImage 并将 u boot 写入 sdcard 的步骤 如下所示 sudo dd if
  • LNK2019问题

    我有一个LNK2019尝试在我的项目中使用某些 DLL 时出现问题 Details 我有一个名为 dll1 的 DLL 项目 编译得很好 使用 declspec dllexport 以便导出 dll1 内的类 供 dll2 使用 我有另一个
  • 如何将现有的 React 应用程序(只是一个没有后端的 UI)插入(注入?)到 SilverStripe 页面布局中?

    我的问题是 我一直在阅读 SilverStripe 4 文档 以便找到一种将现有 React 应用程序 只是没有后端的嵌套 React 组件的 UI 插入 SilverStripe 页面布局的方法 这可能吗 如何确保 SilverStrip
  • php 无法在 wampserver 的 html 代码中工作

    事情是这样的 我有一个名为first php 的文件 其中包含以下代码 welcome br 但是 当我执行它时 php 代码不会被解释 短开标签似乎也已打开 我正在使用 wampserver 我错过了什么 您的服务器似乎配置错 误 您的
  • python 数组赋值与标量赋值

    我有一个二维数组A形状的 4 3 和一个一维数组a形状的 4 我想交换前两行A 以及中的前两个元素a 我做了以下事情 A 0 A 1 A 1 A 0 a 0 a 1 a 1 a 0 显然 它适用于a 但失败了A 现在 第二行成为第一行 但第
  • Selenium IDE:将测试脚本包含到新的测试脚本中

    我们谷歌找到解决方案但没有成功 我们如何将已经录制的脚本添加到新脚本中 Selenium Core有一个扩展 include 可以将另一个测试的内容添加到当前测试中 这是 OpenQA wiki 上的页面 http wiki openqa
  • 编写一个终端仿真器,里面有什么?

    这有点关系到这个问题关于 cmd exe 的更好的 shell 终端 gui 界面 在我寻找更好的 shell 终端的过程中 我遇到的唯一有用的东西是Console2 其他替代品不是免费的 而且通常不会比 Console2 提供更多的功能来
  • 我在这个乒乓球游戏中制作了一个边界,但球拍可以穿过它。我该如何阻止呢?

    我在这个乒乓球游戏中做了一个边框 屏幕上的球拍可以越过它 我之前已经在另一段代码中完成了此操作 但现在一切都不同了 我有一个关于如何做到这一点的主要想法 你可能需要一个 if 语句 但我没有一切 您可以删除 pygame load imag
  • 使用模式在 Jasper Reports 中设置货币格式

    我有一个查询从表中返回金额 select bus price from mySchema BusTable 这将返回如下金额 526547 123456 456789 25 12478 35 我在贾斯珀报告中使用了上述金额 但是 我希望报告
  • 快速裁剪视频

    我正在方形 UIView 中录制视频 但是当我导出视频时 视频是全屏 1080x1920 现在我想知道如何将视频从全屏缩小为方形比例 1 1 以下是我设置摄像机的方法 session AVCaptureSession for device
  • 数据匹配算法

    我目前正在开展一个项目 需要实现数据匹配算法 外部系统传递它所知道的有关客户的所有数据 而我设计的系统必须返回匹配的客户 因此 外部系统知道客户的正确 ID 并获取其他数据或可以更新其自己的特定客户数据 传入以下字段 Name Name2
  • 使用 .clone() 复制二维数组仍然引用原始数据

    好的 我知道这个问题之前已经被问过 上一个问题 我还研究了其他一些线程和网站 它们似乎都产生了比答案更多的问题 乔什 布洛赫谈设计 一篇文章讨论 clone 但我仍然无法找到问题的答案 当我克隆二维数组时 values Map mapVal
  • Postgres 连接表的唯一多列索引

    我在 Postgres 中有一个多对多连接表 我想将其索引到 A 提高性能 显然 和 B 强制唯一性 例如 a id b id 1 2 lt okay 1 3 lt okay 2 3 lt okay 1 3 lt not okay same
  • 回显到文件而不带换行符(批量)[重复]

    这个问题在这里已经有答案了 我的生活创造者计划有问题 它只是选择代表某些内容的随机数并将其放在一起 我尝试过组合变量 这是代码 set num 1 SET A a RANDOM 10 32768 1 if a 10 set life num
  • 在哪里为所有 HttpRequest 设置自定义 ClaimsPrincipal

    我正在将旧应用程序移植到 ASP NET Core 它使用 Windows 身份验证 在 IIS 中配置 分别为 launchsetting json 在开发模式下运行时 我想覆盖身份验证以使用自定义硬编码的 ClaimsPrincipal
  • 如何防止 QTableview 中过于激进的文本删除?

    I have an issue with text elide in Qt being too aggressive in a table see picture 带有完整数字0 8888的单元格 自从显示QTableWidget以来我已经
  • PyTorch:DecoderRNN:运行时错误:输入必须有 3 个维度,得到 2 个维度

    我正在使用 PyTorch 构建 DecoderRNN 这是一个图像标题解码器 class DecoderRNN nn Module def init self embed size hidden size vocab size super