预期设备类型为 cuda 的对象,但在 Pytorch 中获得了设备类型 cpu

2024-05-07

我有以下计算损失函数的代码:

class MSE_loss(nn.Module):
    """ 
    : metric: L1, L2 norms or cosine similarity
    : mode: training or evaluation mode
    """

    def __init__(self,metric, mode, weighted_sum = False):
        super(MSE_loss, self).__init__()
        self.metric = metric.lower()
        self.loss_function = nn.MSELoss()
        self.mode = mode.lower()
        self.weighted_sum = weighted_sum

    def forward(self, output1, output2, labels):
        self.labels = labels         
        self.linear = nn.Linear(output1.size()[0],1)

        if self.metric == 'cos':
            self.d= F.cosine_similarity(output1, output2)
        elif self.metric == 'l1':
            self.d = torch.abs(output1-output2)
        elif self.metric == 'l2':
            self.d = torch.sqrt((output1-output2)**2)

        def dimensional_reduction(forward):
            if self.weighted_sum:
                distance = self.linear(self.d)
            else:
                distance = torch.mean(self.d,1)
            return distance

        def estimate_loss(forward):
            distance = dimensional_reduction(self.d)
            pred = torch.exp(-distance)
            pred = torch.round(pred)
            loss = self.loss_function(pred, self.labels)
            return pred, loss

        pred, loss = estimate_loss(self.d)

        if self.mode == 'training':
            return loss
        else:
            return pred, loss

Given

criterion = MSE_loss('l1','training', weighted_sum = True)

我想在执行标准时通过自线性神经元后获得距离。但是,我收到错误提示“预期设备类型为 cuda 对象,但在调用 _th_addmm 时获得了参数 #1 'self' 的设备类型 cpu”,表明出现了问题。我省略了代码的第一部分,但我提供了完整的错误消息,以便您可以了解发生了什么。

RuntimeError                              Traceback (most recent call last)
<ipython-input-253-781ed4791260> in <module>()
      7 criterion = MSE_loss('l1','training', weighted_sum = True)
      8 
----> 9 train(test_net, train_loader, 10, batch_size, optimiser, clip, criterion)

<ipython-input-207-02fecbfe3b1c> in train(SNN, dataloader, epochs, batch_size, optimiser, clip, criterion)
     57 
     58             # calculate the loss and perform backprop
---> 59             loss = criterion(output1, output2, labels)
     60             a = [[n,p, p.grad] for n,p in SNN.named_parameters()]
     61 

~/.conda/envs/dalkeCourse/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    539             result = self._slow_forward(*input, **kwargs)
    540         else:
--> 541             result = self.forward(*input, **kwargs)
    542         for hook in self._forward_hooks.values():
    543             hook_result = hook(self, input, result)

<ipython-input-248-fb88b987ce71> in forward(self, output1, output2, labels)
     49             return pred, loss
     50 
---> 51         pred, loss = estimate_loss(self.d)
     52 
     53         if self.mode == 'training':

<ipython-input-248-fb88b987ce71> in estimate_loss(forward)
     43 
     44         def estimate_loss(forward):
---> 45             distance = dimensional_reduction(self.d)
     46             pred = torch.exp(-distance)
     47             pred = torch.round(pred)

<ipython-input-248-fb88b987ce71> in dimensional_reduction(forward)
     36             else:
     37                 if self.weighted_sum:
---> 38                     self.d = self.linear(self.d)
     39                 else:
     40                     self.d = torch.mean(self.d,1)

~/.conda/envs/dalkeCourse/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    539             result = self._slow_forward(*input, **kwargs)
    540         else:
--> 541             result = self.forward(*input, **kwargs)
    542         for hook in self._forward_hooks.values():
    543             hook_result = hook(self, input, result)

~/.conda/envs/dalkeCourse/lib/python3.6/site-packages/torch/nn/modules/linear.py in forward(self, input)
     85 
     86     def forward(self, input):
---> 87         return F.linear(input, self.weight, self.bias)
     88 
     89     def extra_repr(self):

~/.conda/envs/dalkeCourse/lib/python3.6/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1368     if input.dim() == 2 and bias is not None:
   1369         # fused op is marginally faster
-> 1370         ret = torch.addmm(bias, input, weight.t())
   1371     else:
   1372         output = input.matmul(weight.t())

RuntimeError: Expected object of device type cuda but got device type cpu for argument #1 'self' in call to _th_addmm

然而 self.d 是一个张量,但它已经被传递到 GPU 中,如下所示:

self.d =
tensor([[3.7307e-04, 8.4476e-04, 4.0426e-04,  ..., 4.2015e-04, 1.7830e-04,
         1.2833e-04],
        [3.9271e-04, 4.8325e-04, 9.5238e-04,  ..., 1.5126e-04, 1.3420e-04,
         3.9260e-04],
        [1.9278e-04, 2.6530e-04, 8.6903e-04,  ..., 1.6985e-05, 9.5103e-05,
         1.9610e-04],
        ...,
        [1.8257e-05, 3.1304e-04, 4.6398e-04,  ..., 2.7327e-04, 1.1909e-04,
         1.5069e-04],
        [1.7577e-04, 3.4820e-05, 9.4168e-04,  ..., 3.2848e-04, 2.2514e-04,
         5.4275e-05],
        [4.2916e-04, 1.6155e-04, 9.3186e-04,  ..., 1.0950e-04, 2.5083e-04,
         3.7374e-06]], device='cuda:0', grad_fn=<AbsBackward>)

In the forward你的MSE_loss,您定义一个线性层probably仍在CPU中(你没有提供MCVE https://stackoverflow.com/help/minimal-reproducible-example,所以我只能假设):

self.linear = nn.Linear(output1.size()[0], 1)

如果您想尝试看看这是否是问题所在,您可以:

self.linear = nn.Linear(output1.size()[0], 1).cuda()

然而,如果self.d是在CPU中,那么它会再次失败。要解决此问题,您可以将线性移动到与self.d张量通过这样做:

def forward(self, output1, output2, labels):
    self.labels = labels         
    self.linear = nn.Linear(output1.size()[0], 1)

    if self.metric == 'cos':
        self.d = F.cosine_similarity(output1, output2)
    elif self.metric == 'l1':
        self.d = torch.abs(output1-output2)
    elif self.metric == 'l2':
        self.d = torch.sqrt((output1-output2)**2)

    # move self.linear to the correct device
    self.linear = self.linear.to(self.d.device)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

预期设备类型为 cuda 的对象,但在 Pytorch 中获得了设备类型 cpu 的相关文章

随机推荐

  • 使用 Azure AD B2C 登录 Xamarin Android 应用

    经过一周的研究可与 Azure AD B2C 一起使用 Xamarin 以 Android 平台 而不是 Xamarin Forms 为目标的身份验证原理后 我终于寻求一些建议 我有一个带有 登录 按钮的活动 我想通过按钮的触摸事件登录到
  • 蓝牙 LE:地址类型

    我正在研究 iBeacon 技术 但我找不到有关地址类型的特定问题的任何答案 我找到了解释地址类型的文档 蓝牙规范 但我似乎找不到如何在两种类型 公共和随机 之间进行选择 这是我发现它的一个例子 它是由 Raspberry PI 上的 iB
  • React Native:如何在组件中添加脚本标签

    我正在尝试在 React Native 应用程序的组件内添加标签 下面是我的代码 它似乎不起作用 谁能告诉我如何解决这个问题 import React Component from react import PropTypes from p
  • Tensorflow无法分配设备进行操作

    我正在尝试跑步NVidia 脸部生成器演示 https github com tkarras progressive growing of gans在我的电脑上 我使用的是 Windows 10 我已经下载了源代码 并尝试按照页面下方的步骤
  • WPF DataGrid DataBindingComplete 事件在哪里?

    数据绑定完成后 我需要采取一些操作 例如 根据其他一些单元格使某些单元格只读 在WinForm DataGridView中 我曾经在DataBindingComplete事件中执行此操作 但是 我在 WPF DataGrid 中找不到这样的
  • CouchDB 视图中的链接文档

    我很难理解 CouchDB链接文档 http wiki apache org couchdb Introduction to CouchDB views Linked documents特征 我有两个types存储在单个 CouchDB 数
  • asp.net mvc 3 中模糊的远程属性验证

    asp net mvc 3 中的内置远程属性会执行 onchange 验证 我希望它在模糊时验证 有没有办法自定义它 或者还有其他东西可以这样做 我确信这是一个非常普遍的需求 你可以设置默认值 http docs jquery com Pl
  • 如何从 PySpark 中某个表中找到的多个表中获取所有数据?

    我正在使用 pyspark SQL 我有一个包含三列的表 MAIN TABLE DATABASE NAME TABLE NAME SOURCE TYPE 我想从 DATABASE NAME 和 TABLE NAME 列中的主表下找到的实际数
  • libxml2 用缩进解析文档

    我正在尝试调试正在解析包含缩进的 xml 文档的代码 我正在尝试找出在 xmlReadMemory 函数上使用的正确参数 XML PARSE NOBLANKS 选项对以下方法调用有何作用 xmlReadMemory buffer data
  • 设计 GUI [关闭]

    Closed 这个问题需要多问focused help closed questions 目前不接受答案 作为一个几乎没有 或没有 艺术倾向的开发人员 您将如何为应用程序设计 GUI 特别是 我正在考虑桌面应用程序 但任何与网络应用程序相关
  • matplotlib pyplot:子图大小

    如果我绘制如下所示的单个图 它将具有 x y 大小 import matplotlib pyplot as plt plt plot 1 2 1 2 但是 如果我在同一行中绘制 3 个子图 则每个子图的大小均为 x 3 y fig ax p
  • Javascript document.getElementsByClassName 返回未定义

    我有一个函数应该相当简单 并且应该在加载后完成 以减少初始加载时间 基本上我使用这段代码来获取类 prefImg 的所有元素并用它们做一些事情 但是在firebug中调试时 它说var divsList未定义 function popula
  • Cgo 生成的源无法在 MVC 上编译

    我有一个用 CGo 制作的共享库 它在 Linux 和 Android 上链接得很好 但是 当使用 Microsoft Visual Studio 2017 在 Windows 10 上进行编译时 出现以下错误 Microsoft R Pr
  • Typescript 泛型 - “扩展对象”毫无意义吗?最佳实践是什么?

    我注意到了 p 泛型通常是没有意义的 因为基本上 javascript 中的所有内容都是对象 大多数文字都是具有 toString 方法的对象 字符串是具有 length 属性等的对象 我更喜欢 p p 但很好奇其他人注意到了什么 我现在没
  • 警告:“沙箱”不在已知选项列表中,但仍传递给 Electron/Chromium

    我在用Linux Mint 20 and vscode 1 52 1 My xsession errors文件显示Warning sandbox is not in the list of known options but still p
  • Swift 和 Xcode - 如何创建自定义选项卡栏图标

    我正在用 Swift 编写的 Xcode 中处理一个选项卡式应用程序项目 Xcode 6 3 and Swift 1 2 我在使用自定义选项卡栏图标时遇到很多麻烦 我在 Photoshop CS6 中设计了一张图像 将其保存为 PNG 在
  • 具有相对 URL 的 CSS 图像有时相对于 IE 中的页面 URL

    我似乎发现 IE 有时会尝试使用相对 URL 加载 CSS 图像 相对于页面 url 而不是 CSS 文件 url 示例 有人加载此网址 https www main events com event 234 my awesome show
  • JavaScript 删除除一个之外的所有隐藏元素

    有人帮我找到了 JavaScript从提交中删除隐藏表单字段的代码 https stackoverflow com questions 7745191 javascript removing contents of form hidden
  • 内容长度标头与分块编码

    我正在尝试权衡设置的利弊Content LengthHTTP 标头与使用分块编码从我的服务器返回 可能 大文件的比较 使用持久连接需要其中之一来符合 HTTP 1 1 规范 我看到了的优点Content Length标头是 下载对话框可以显
  • 预期设备类型为 cuda 的对象,但在 Pytorch 中获得了设备类型 cpu

    我有以下计算损失函数的代码 class MSE loss nn Module metric L1 L2 norms or cosine similarity mode training or evaluation mode def init