RuntimeError: 维度超出范围(预期在 [-1, 0] 范围内,但得到 1)

2024-01-13

我使用 Pytorch Unet 模型,将图像作为输入,同时将标签作为输入图像掩码,并在其上训练数据集。 我从其他地方获得的 Unet 模型,我使用交叉熵损失作为损失函数,但我得到了这个维度超出范围的错误,

RuntimeError                              
Traceback (most recent call last)
<ipython-input-358-fa0ef49a43ae> in <module>()
     16 for epoch in range(0, num_epochs):
     17     # train for one epoch
---> 18     curr_loss = train(train_loader, model, criterion, epoch, num_epochs)
     19 
     20     # store best loss and save a model checkpoint

<ipython-input-356-1bd6c6c281fb> in train(train_loader, model, criterion, epoch, num_epochs)
     16         # measure loss
     17         print (outputs.size(),labels.size())
---> 18         loss = criterion(outputs, labels)
     19         losses.update(loss.data[0], images.size(0))
     20 

/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py in     _ _call__(self, *input, **kwargs)
    323         for hook in self._forward_pre_hooks.values():
    324             hook(self, input)
--> 325         result = self.forward(*input, **kwargs)
    326         for hook in self._forward_hooks.values():
    327             hook_result = hook(self, input, result)

<ipython-input-355-db66abcdb074> in forward(self, logits, targets)
      9         probs_flat = probs.view(-1)
     10         targets_flat = targets.view(-1)
---> 11         return self.crossEntropy_loss(probs_flat, targets_flat)

/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py in     __call__(self, *input, **kwargs)
    323         for hook in self._forward_pre_hooks.values():
    324             hook(self, input)
  --> 325         result = self.forward(*input, **kwargs)
    326         for hook in self._forward_hooks.values():
    327             hook_result = hook(self, input, result)

/usr/local/lib/python3.5/dist-packages/torch/nn/modules/loss.py in f orward(self, input, target)
    599         _assert_no_grad(target)
    600         return F.cross_entropy(input, target, self.weight, self.size_average,
--> 601                                self.ignore_index, self.reduce)
    602 
    603 

/usr/local/lib/python3.5/dist-packages/torch/nn/functional.py in     cross_entropy(input, target, weight, size_average, ignore_index, reduce)
   1138         >>> loss.backward()
   1139     """
-> 1140     return nll_loss(log_softmax(input, 1), target, weight, size_average, ignore_index, reduce)
   1141 
   1142 

/usr/local/lib/python3.5/dist-packages/torch/nn/functional.py in     log_softmax(input, dim, _stacklevel)
    784     if dim is None:
    785         dim = _get_softmax_dim('log_softmax', input.dim(),      _stacklevel)
--> 786     return torch._C._nn.log_softmax(input, dim)
    787 
    788 

RuntimeError: dimension out of range (expected to be in range of [-1, 0], but got 1)

我的部分代码如下所示

class crossEntropy(nn.Module):
    def __init__(self, weight = None, size_average = True):
        super(crossEntropy, self).__init__()
        self.crossEntropy_loss = nn.CrossEntropyLoss(weight, size_average)
        
    def forward(self, logits, targets):
        probs = F.sigmoid(logits)
        probs_flat = probs.view(-1)
        targets_flat = targets.view(-1)
        return self.crossEntropy_loss(probs_flat, targets_flat)


class UNet(nn.Module):
    def __init__(self, imsize):
        super(UNet, self).__init__()
        self.imsize = imsize

        self.activation = F.relu
        
        self.pool1 = nn.MaxPool2d(2)
        self.pool2 = nn.MaxPool2d(2)
        self.pool3 = nn.MaxPool2d(2)
        self.pool4 = nn.MaxPool2d(2)
        self.conv_block1_64 = UNetConvBlock(4, 64)
        self.conv_block64_128 = UNetConvBlock(64, 128)
        self.conv_block128_256 = UNetConvBlock(128, 256)
        self.conv_block256_512 = UNetConvBlock(256, 512)
        self.conv_block512_1024 = UNetConvBlock(512, 1024)

        self.up_block1024_512 = UNetUpBlock(1024, 512)
        self.up_block512_256 = UNetUpBlock(512, 256)
        self.up_block256_128 = UNetUpBlock(256, 128)
        self.up_block128_64 = UNetUpBlock(128, 64)

        self.last = nn.Conv2d(64, 2, 1)


    def forward(self, x):
        block1 = self.conv_block1_64(x)
        pool1 = self.pool1(block1)

        block2 = self.conv_block64_128(pool1)
        pool2 = self.pool2(block2)

        block3 = self.conv_block128_256(pool2)
        pool3 = self.pool3(block3)

        block4 = self.conv_block256_512(pool3)
        pool4 = self.pool4(block4)

        block5 = self.conv_block512_1024(pool4)

        up1 = self.up_block1024_512(block5, block4)

        up2 = self.up_block512_256(up1, block3)

        up3 = self.up_block256_128(up2, block2)

        up4 = self.up_block128_64(up3, block1)

        return F.log_softmax(self.last(up4))

根据你的代码:

probs_flat = probs.view(-1)
targets_flat = targets.view(-1)
return self.crossEntropy_loss(probs_flat, targets_flat)

你给了两个一维张量nn.CrossEntropyLoss但根据文档 http://pytorch.org/docs/0.3.0/nn.html?highlight=crossentropyloss#torch.nn.CrossEntropyLoss,它期望:

Input: (N,C) where C = number of classes
Target: (N) where each value is 0 <= targets[i] <= C-1
Output: scalar. If reduce is False, then (N) instead.

我相信这就是您遇到问题的原因。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

RuntimeError: 维度超出范围(预期在 [-1, 0] 范围内,但得到 1) 的相关文章

随机推荐

  • 响应式网页设计技巧、最佳实践和动态图像缩放技术[关闭]

    Closed 这个问题是基于意见的 help closed questions 目前不接受答案 我希望这个问题不会因为主题太宽泛而结束 但我想知道响应式 自适应网页设计 即适用于所有浏览器 所有设备的一个网站 在结构和布局方面 实现此类网站
  • Firestore / React 在 componentWillUnmount 中需要取消订阅

    我有一个位于特定路线的组件
  • AWS Amplify:DevTools 无法加载 SourceMap:JSON 中位置 0 处出现意外标记 <

    在 Google Chrome 上加载我的网站时 我收到了一些类似于以下内容的警告 DevTools failed to load SourceMap Could not parse content for https mywebsite
  • 通过 NSPredicate 在 NSString 中进行“整个单词”搜索

    我想在属性中搜索description an NSString实例 有一个给定的单词 我尝试使用这个谓词 NSPredicate predicateWithFormat description CONTAINS cd theWord 它有效
  • 撇号的正则表达式

    我正在寻找一个正则表达式来查找字符串中的撇号 该字符串也可以是一个句子 我尝试了一个简单的正则表达式 如 但它只检查字符串中的一个字符 如何检查整个字符串 例如 Hello I have many PC s 将是一场比赛 但 I dont
  • 保持 git clean 历史记录的最佳实践是什么?

    在阅读有关 git 工作流程的文章时 我想知道历史重写的适当性 我的工作流程以及我想象的许多其他人的工作流程是这样的 获取 Github 存储库 我们称其为rep1 制作一个叉子 这将是rep2 git 将其克隆到本地以进行使用 即rep3
  • “ascii”编解码器无法对位置 * 或不在范围内的字符进行编码 (128)

    stackoverflow 上有一些线程 但我找不到整个问题的有效解决方案 我从 urllib 读取函数收集了大量文本数据并将其存储在 pickle 文件中 现在我想将这些数据写入文件 写作时我遇到类似的错误 ascii codec can
  • 如何在使用 Eigen Library C++ 时删除特定行或列

    我正在为我的项目使用 Eigen 库 我正在搜索如何从特征中的给定矩阵中删除特定行或列 我没有成功 MatrixXd A X1 X2 X3 X4 Y1 Y2 Y3 Y4 Z1 Z2 Z3 Z4 A1 A2 A3 A4 MatrixXd At
  • SQL Server中两个日期之间的月差

    请参考以下示例 并请告诉我您的想法 declare EmployeeStartDate datetime 01 Sep 2013 declare EmployeeEndDate datetime 15 Nov 2013 select Dat
  • ARCore 在按钮单击时保存相机图像 (Unity C#)

    我有一个类似的问题 例如以下三个问题 将 Unity ARCore 中的 AcquireCameraImageBytes 作为图像保存到存储 https stackoverflow com questions 49579334 save a
  • 我可以从 Google 表格脚本生成文件吗?

    我正在使用 Google Sheets 为我正在做的事情制作一堆数值数据的原型 有没有办法将子集导出到文本文件 实际上 我的目标是导出一个可以直接包含在另一个项目的构建中的文件 那么有没有办法生成文本文件供下载呢 如果您有 Google A
  • 文件读取器内存泄漏

    我正在使用 FileReader 将图像文件上传到客户端 用于数据获取和缩略图显示 我注意到的是 在页面进程上 在任务管理器中 内存只会越来越高 当进程停止时 内存保持在高位并且永远不会下降 你能告诉我我在这里做错了什么吗 如需查看 请上传
  • 找不到模块:无法解析“@date-io/date-fns”

    我在用着反应材料用户界面 https material ui com我收到此错误 找不到模块 无法解析 date io date fns 以下是我的 package json 文件中的依赖项 dependencies date io dat
  • 当需要日志记录时,您会考虑哪种设计模式?

    我正在开发的应用程序需要将操作 执行该操作的用户以及操作时间记录到数据库中 哪种设计模式最流行 最适合日志记录 我在想命令模式需要当前用户和操作 执行操作并写入日志 你怎么认为 我可以考虑其他替代方案吗 谢谢 您可以使用AOP http e
  • 如何删除Jenkins下的View而不影响现有作业

    我想删除Jenkins下的Views而不影响视图下的Jobs 我之所以问这个问题 是因为即使以管理员身份删除它后 我也无法输入相同的视图名称 我检查了 Jenkins 文件夹下的配置文件并尝试编辑视图名称 但这不起作用 我需要确认以下脚本是
  • 图钉调整绑定缩放级别大小

    我将 WinRT 与 bing 地图结合使用 并尝试在缩放地图时设置 以编程方式 图钉的 RenderTransform 值 我试过这个Solution http social msdn microsoft com Forums en US
  • Django - 如何在不修改的情况下扩展第 3 方模型

    我想向数据库表添加一列 但我不想修改第 3 方模块 以防我将来需要 决定升级模块 有没有办法可以在我的代码中添加此字段 以便在新版本中我不必手动添加该字段 您可以使用 ModelName add to class 或 contribute
  • pip3 ImportError:无法导入名称“IncompleteRead”

    通过安装模块时遇到问题pip3 尝试了 2014 年 12 月以来投票最高的帖子中的一些建议 但仍然得到以下结果 sudo pip3 install send2trash Traceback most recent call last Fi
  • 使用 GAE 限制对静态文件的访问

    我有一个静态文件 我不想公开该文件 有没有办法限制 app yaml 的访问 使其只能由自己的域加载 基于 web2py 的解决方案也很受欢迎 因为我在 GAE 之上使用它 Thanks 您可以使用 登录 必需 来限制对其的访问 以要求使用
  • RuntimeError: 维度超出范围(预期在 [-1, 0] 范围内,但得到 1)

    我使用 Pytorch Unet 模型 将图像作为输入 同时将标签作为输入图像掩码 并在其上训练数据集 我从其他地方获得的 Unet 模型 我使用交叉熵损失作为损失函数 但我得到了这个维度超出范围的错误 RuntimeError Trace