pytorch 中图像分割的通道明智 CrossEntropyLoss

2024-02-01

我正在做图像分割任务。总共有 7 个类,所以最终的输出是像 [batch, 7, height, width] 这样的张量,它是一个 softmax 输出。现在直觉上我想使用 CrossEntropy 损失,但 pytorch 实现不适用于通道明智的单热编码向量

所以我打算自己做一个功能。在一些 stackoverflow 的帮助下,我的代码到目前为止看起来像这样

from torch.autograd import Variable
import torch
import torch.nn.functional as F


def cross_entropy2d(input, target, weight=None, size_average=True):
    # input: (n, c, w, z), target: (n, w, z)
    n, c, w, z = input.size()
    # log_p: (n, c, w, z)
    log_p = F.log_softmax(input, dim=1)
    # log_p: (n*w*z, c)
    log_p = log_p.permute(0, 3, 2, 1).contiguous().view(-1, c)  # make class dimension last dimension
    log_p = log_p[
       target.view(n, w, z, 1).repeat(0, 0, 0, c) >= 0]  # this looks wrong -> Should rather be a one-hot vector
    log_p = log_p.view(-1, c)
    # target: (n*w*z,)
    mask = target >= 0
    target = target[mask]
    loss = F.nll_loss(log_p, target.view(-1), weight=weight, size_average=False)
    if size_average:
        loss /= mask.data.sum()
    return loss


images = Variable(torch.randn(5, 3, 4, 4))
labels = Variable(torch.LongTensor(5, 3, 4, 4).random_(3))
cross_entropy2d(images, labels)

我收到两个错误。代码本身提到了一个,它需要 one-hot 向量。第2个说的是以下内容

RuntimeError: invalid argument 2: size '[5 x 4 x 4 x 1]' is invalid for input with 3840 elements at ..\src\TH\THStorage.c:41

例如,我试图让它解决 3 类问题。所以目标和标签是(为了简化不包括批处理参数!)

Target:

 Channel 1     Channel 2  Channel 3

[[0 1 1 0 ] [0 0 0 1 ] [1 0 0 0 ] [0 0 1 1 ] [0 0 0 0 ] [1 1 0 0 ] [0 0 0 1 ] [0 0 0 0 ] [1 1 1 0 ] [0 0 0 0 ] [0 0 0 1 ] [1 1 1 0 ]

Labels:

 Channel 1     Channel 2  Channel 3

[[0 1 1 0 ] [0 0 0 1 ] [1 0 0 0 ] [0 0 1 1 ] [.2 0 0 0] [.8 1 0 0 ] [0 0 0 1 ] [0 0 0 0 ] [1 1 1 0 ] [0 0 0 0 ] [0 0 0 1 ] [1 1 1 0 ]

那么我如何修复我的代码来计算通道明智的 CrossEntropy 损失?


正如 Shai 的回答已经指出的那样,关于torch.nn.CrossEntropy()可以找到函数here https://pytorch.org/docs/stable/nn.html#torch.nn.CrossEntropyLoss并且可以找到代码here https://pytorch.org/docs/stable/_modules/torch/nn/modules/loss.html。内置函数确实已经支持 KD 交叉熵损失。

在 3D 情况下,torch.nn.CrossEntropy()函数需要两个参数:4D 输入矩阵和 3D 目标矩阵。输入矩阵的形状为:(Minibatch, Classes, H, W)。目标矩阵的形状为 (Minibatch, H, W),数字范围为 0 到 (Classes-1)。如果您从 one-hot 编码矩阵开始,则必须将其转换为np.argmax().

具有三个类别且小批量大小为 1 的示例:

import pytorch
import numpy as np

input_torch = torch.randn(1, 3, 2, 5, requires_grad=True)

one_hot = np.array([[[1, 1, 1, 0, 0], [0, 0, 0, 0, 0]],    
                    [[0, 0, 0, 0, 0], [1, 1, 1, 0, 0]],
                    [[0, 0, 0, 1, 1], [0, 0, 0, 1, 1]]])

target = np.array([np.argmax(a, axis = 0) for a in target])
target_torch = torch.tensor(target_argmax)

loss = torch.nn.CrossEntropyLoss()
output = loss(input_torch, target_torch)
output.backward()
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch 中图像分割的通道明智 CrossEntropyLoss 的相关文章

随机推荐

  • if-let 语句不会解开可选内容

    我在我的代码中遇到了一些看起来很好奇的东西 并且想知道这种行为是否有一个简单的解释 鉴于以下声明 if let tabBarController topViewController as UITabBarController for sub
  • 如何推迟shared_ptr的删除操作?

    我创建了一个指针sample主要类 我正在将此指针传递给函数function1 该函数必须使用指针作为共享指针并使用该指针执行一些操作 退出期间function1 的析构函数sample由于调用shared ptr 当我将相同的指针传递给不
  • 实现自定义 LINQ-to-X 提供程序

    我有一个搜索工具 它接受复杂的搜索字符串 实际上是 JSON 中的 n 级对象图 并返回一些结果 我想通过类似 LINQ 的机制向其他 内部 开发人员公开该功能 假设每个结果都是由一个类定义的Result 我可以创建类似以下的方法 Func
  • Google+ 的共享意图无法访问图像

    我正在调用共享图像的意图 这适用于大多数提供商 但适用于 Google Google 打开不带图片的帖子活动 并显示 Toast 您只能发布存储在设备上的照片 同时 File f storeImage image f data data c
  • 在头文件与 .cpp 文件中编码 C++(主要)

    多年来 我一直以标准方式编写 C 代码 在头文件 hpp 中使用类声明 在源 cpp 文件中使用函数定义 最近 我搬到了一家新公司 其中的代码 似乎受到 boost 编码风格的影响 完全用 hpp 文件进行编码 并用一个短的 cpp 文件来
  • Visual Studio Extensions - 支持多个版本的VS

    我一直在编写一个扩展 编辑器分类器项目 带有一些其他功能 它在 VS2013 上运行良好 但我需要支持其他版本 VS2012 和 VS2015 当它超出预览时 当我刚刚添加支持的版本时vsixmanifest 我面临的问题是 出口ITest
  • js函数不调用自动填充函数

    我有一个输入字段 其中包含州名称并在其他字段上显示相应的区域 当我更改状态字段的值时 区域的值也应该更改 它不适用于我的代码 这有什么问题 div class col md 4 form group style padding right
  • MongoDb 使用位置运算符拉取

    我的文档结构如下 id 12342342 items ownerId 123 dates 2014 10 01 2014 10 02 ownerId 234 dates 2014 10 01 2014 10 02 我想从父对象的ownerI
  • 低延迟、大规模消息队列

    我正在重新思考 Facebook 应用程序和云计算时代的大型多人游戏 假设我要在现有开放协议之上构建一些东西 并且我想为 1 000 000 个同时玩家提供服务 只是为了解决问题 假设每个玩家都有一个传入消息队列 用于聊天等 平均还有一个传
  • 查找带有 USB 设备 VID/PID 的 /dev 条目

    我想制作一个程序来检测哪些 dev sd 条目链接到已知的 USB VID PID 对 你知道我如何获得 USB 记忆棒的 VID PID 吗 dev sd 您可以使用udevadm为了这 在输出中udevadm info q proper
  • 通过 UIBezierPath 移动 CALayer

    我有一个图层将从 UIBezierPath 上的 A 点移动到 B 点 我发现了很多涉及 CAAnimation 和 UIBezierPath 的示例 但我只需将图层从指定点移动到贝塞尔曲线路径上的另一个点 任何建议 将不胜感激 Thank
  • 裁剪图像时添加细白线 (Objective-C OSX)

    我正在剪切一张大图像并将其保存为许多不同的图像 我首先在iOS它工作正常 但是当我尝试将代码移植到OSX 一条细白线 1 像素 出现在图像的顶部和右侧 该线不是纯白色或实线 参见下面的示例 这里是iOS制作一个子图像的代码 就像冠军一样 v
  • 如何从 HIVE 中的日期减去月份

    我正在寻找一种方法来帮助我从 HIVE 中的日期中减去月份 我有个约会2015 02 01 现在我需要从这个日期减去 2 个月 这样结果应该是2014 12 01 你们能帮我一下吗 select add months 2015 02 01
  • BigQuery 选择 * 两列除外

    我想从公共 BigQuery github repos 数据集中选择除以下两条记录之外的所有内容 author nameAND差异 old mode 根据我问的类似问题 我想我想运行类似于 standardSQL SELECT REPLAC
  • 如何仅在functions.auth.user().onCreate完成后完成登录

    我正在使用 firebase 函数 并且有一个在用户创建时添加新集合的函数 问题是有时用户在功能完成之前登录 因此用户已登录但尚未创建新集合 然后我收到错误消息 缺少或权限不足 因为规则找不到该集合 我该如何处理 仅当所有内容都来自时 是否
  • 在表格单元格内垂直拉伸 div - IE8

    如何在表格单元格内垂直拉伸 DIV 我想height 100 就可以了但在某些情况下 事实并非如此 至少在 IE8 中 这是一个简单的例子 一个 3 行表格 包含页眉 内容和页脚 我希望 内容 单元格内的 内容 DIV 垂直拉伸 100 在
  • 如何组合 kotlin 委托属性:可观察、可否决和“按映射”?

    我正在尝试结合代表 可观察的 https kotlinlang org api latest jvm stdlib kotlin properties delegates observable html with vetoable http
  • .NET Framework 4.0 安装程序是否也安装了 .NET 3.5?

    NET 4 0 旨在与 3 5 并行运行 并且不会运行 3 5 应用程序 这让我担心必须指示我的用户下载 NET 3 5 而不仅仅是 最新版本 我在一篇博客中读到 如果尚未安装 4 0 安装程序也会安装 3 5 但我现在无法测试它 有人尝试
  • 如何使用 XPath 和 Java 更新 XML

    我有一个 XML 文档以及该文档的 XPath 表达式 我必须在运行时使用 XPath 更新文档 我如何使用 Java 来做到这一点 下面是我的xml
  • pytorch 中图像分割的通道明智 CrossEntropyLoss

    我正在做图像分割任务 总共有 7 个类 所以最终的输出是像 batch 7 height width 这样的张量 它是一个 softmax 输出 现在直觉上我想使用 CrossEntropy 损失 但 pytorch 实现不适用于通道明智的