为什么测试时一定要用DataParallel?

2024-03-31

在GPU上训练,num_gpus设置为1:

device_ids = list(range(num_gpus))
model = NestedUNet(opt.num_channel, 2).to(device)
model = nn.DataParallel(model, device_ids=device_ids)

CPU上测试:

model = NestedUNet_Purn2(opt.num_channel, 2).to(dev)
device_ids = list(range(num_gpus))
model = torch.nn.DataParallel(model, device_ids=device_ids)
model_old = torch.load(path, map_location=dev)
pretrained_dict = model_old.state_dict()
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

这样会得到正确的结果,但是当我删除时:

device_ids = list(range(num_gpus))
model = torch.nn.DataParallel(model, device_ids=device_ids)

结果是错误的。


nn.DataParallel包装模型,其中实际模型被分配给module属性。这也意味着状态字典中的键有一个module. prefix.

让我们看一个非常简化的版本,只有一个卷积来看看差异:

class NestedUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)

model = NestedUNet()

model.state_dict().keys() # => odict_keys(['conv1.weight', 'conv1.bias'])

# Wrap the model in DataParallel
model_dp = nn.DataParallel(model, device_ids=range(num_gpus))

model_dp.state_dict().keys() # => odict_keys(['module.conv1.weight', 'module.conv1.bias'])

你保存的状态字典nn.DataParallel与常规模型的状态不一致。您要将当前状态字典与加载状态字典合并,这意味着加载状态将被忽略,因为模型没有任何属于键的属性,而是留下随机初始化的模型。

为了避免犯这种错误,您不应该合并状态字典,而应该直接将其应用到模型,在这种情况下,如果键不匹配就会出现错误。

RuntimeError: Error(s) in loading state_dict for NestedUNet:
        Missing key(s) in state_dict: "conv1.weight", "conv1.bias".
        Unexpected key(s) in state_dict: "module.conv1.weight", "module.conv1.bias".

为了使您保存的状态字典兼容,您可以去掉module. prefix:

pretrained_dict = {key.replace("module.", ""): value for key, value in pretrained_dict.items()}
model.load_state_dict(pretrained_dict)

您还可以通过从以下位置解开模型来避免将来出现此问题nn.DataParallel在保存其状态之前,即保存model.module.state_dict()。因此,您始终可以先加载模型的状态,然后再决定将其放入nn.DataParallel如果您想使用多个 GPU。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

为什么测试时一定要用DataParallel? 的相关文章

  • 如何平衡 GAN 中生成器和判别器的性能?

    这是我第一次使用 GAN 我面临着判别器多次优于生成器的问题 我正在尝试重现PA模型来自本文 http openaccess thecvf com content ICCV 2017 papers Sajjadi EnhanceNet Si
  • Keras 自定义损失函数:形状为batch_size (y_true) 的变量

    在 Keras 中实现自定义损失函数时 我需要tf Variable与我的输入数据的批量大小的形状 y true y pred def custom loss y true y pred counter tf Variable tf zer
  • 在pytorch张量中过滤数据

    我有一个张量X like 0 1 0 5 1 0 0 1 2 0 我想实现一个名为的函数filter positive 它可以将正数据过滤成新的张量并返回原始张量的索引 例如 new tensor index filter positive
  • 一次热编码期间出现 RunTimeError

    我有一个数据集 其中类值以 1 步从 2 到 2 i e 2 1 0 1 2 其中 9 标识未标记的数据 使用一种热编码 self one hot encode labels 我收到以下错误 RuntimeError index 1 is
  • pytorch 中的 autograd 可以处理同一模块中层的重复使用吗?

    我有一层layer in an nn Module并在一次中使用两次或多次forward步 这个的输出layer稍后输入到相同的layer pytorch可以吗autograd正确计算该层权重的梯度 def forward x x self
  • 在 model.fit() 期间记录 Keras 中每个时期的计算时间

    我想比较不同模型之间的计算时间 在拟合期间 每个时期的计算时间被打印到控制台 Epoch 5 5 160000 160000 10s 我正在寻找一种方法来存储这些时间 其方式与模型指标类似 模型指标保存在每个时期并可通过历史对象获取 尝试以
  • Keras,训练模型后如何预测?

    我正在使用 reuters example 数据集 它运行良好 我的模型已经过训练 我阅读了有关如何保存模型的信息 以便稍后加载它以再次使用 但如何使用这个保存的模型来预测新文本呢 我用吗models predict 我必须以特殊方式准备这
  • Pytorch ValueError:优化器得到一个空参数列表

    当尝试创建神经网络并使用 Pytorch 对其进行优化时 我得到了 ValueError 优化器得到一个空参数列表 这是代码 import torch nn as nn import torch nn functional as F fro
  • 用于多输入图像的 VGG16 网络

    我正在尝试将 VGG16 网络用于多个输入图像 使用具有 2 个输入的简单 CNN 训练该模型给了我一个 acc 大约 50 这就是为什么我想使用 VGG16 这样的既定模型进行尝试 这是我尝试过的 imports from keras a
  • Keras 自定义损失函数:访问当前输入模式

    在 Keras 带有 Tensorflow 后端 中 当前输入模式可用于我的自定义损失函数吗 当前输入模式被定义为用于产生预测的输入向量 例如 请考虑以下情况 X train X test y train y test train test
  • Google Colab:为什么 CPU 比 TPU 快?

    我正在使用 Google colabTPU训练一个简单的Keras模型 删除分布式strategy并在CPU比TPU 这怎么可能 import timeit import os import tensorflow as tf from sk
  • 结合两个 CNN

    我想在 Keras 中将两个 CNN 合并为一个 我的意思是我希望神经网络拍摄两张图像并在单独的 CNN 中处理每一张图像 然后将它们连接在一起进入扁平化层并使用全连接层来做最后的工作 我做了什么 Start With First Bran
  • Tensorflow:提要字典错误:您必须为占位符张量提供值

    我有一个错误 我无法找出原因 这是代码 with tf Graph as default global step tf Variable 0 trainable False images tf placeholder tf float32
  • BERT 输出不确定

    BERT 输出是不确定的 当我输入相同的输入时 我希望输出值是确定性的 但我的 bert 模型的值正在变化 听起来很尴尬 同一个值返回两次 一次 也就是说 一旦出现另一个值 就会出现相同的值并重复 如何使输出具有确定性 让我展示我的代码片段
  • 是否有可能在每个训练步骤中获得目标函数值?

    在通常的 TensorFlow 训练循环中 例如 train op tf train AdamOptimizer minimize cross entropy with tf Session as sess for i in range n
  • Caffe,在层中设置自定义权重

    I have a network In one place I want to use concat As on this picture 不幸的是 该网络无法训练 为了理解为什么我想连续改变权重 这意味着 FC4096 中的所有值一开始都
  • Pytorch GPU 使用率低

    我正在尝试 pytorch 的例子https pytorch org tutorials beginner blitz cifar10 tutorial html https pytorch org tutorials beginner b
  • pytorch 的 IDE 自动完成

    我正在使用 Visual Studio 代码 最近尝试了风筝 这两者似乎都没有 pytorch 的自动完成功能 这些工具可以吗 如果没有 有人可以推荐一个可以的编辑器吗 谢谢你 使用Pycharmhttps www jetbrains co
  • 使用 Keras 的 ImageDataGenerator 预测单个图像

    我对深度学习很陌生 所以请原谅我这个可能很简单的问题 我训练了一个网络来分类positive and negative 为了简化图像生成和拟合过程 我使用了ImageDataGenerator和fit generator函数 如下图 imp
  • 当我想在电脑中加载该模型时,我可以在 colab bu 中加载我的深度模型,但我不能

    我在colab中通过keras 2 3 1和tensorflow 2 1 0训练了一个深度模型 我用JSON和Keras保存了我的模型 saveWeightPath content drive My Drive model info mod

随机推荐

  • 在SAX解析期间确定根元素

    我正在使用 SAX 来解析 XML 文件 假设我希望我的应用程序only处理带有根元素 animalList 的 XML 文件 如果根节点是其他节点 SAX 解析器应该终止解析 使用 DOM 你可以这样做 Element rootEleme
  • 在 Flutter 中使用 After Effects 文件

    我知道如何导出Rive在 Flutter 应用程序中使用的 Flare 文件 但我怎样才能import Adobe 后遗症文件到 Rive 我知道可以这样做Lottie但我无法弄清楚如何准确地做到这一点 您可以轻松导入 bodymovin
  • 如何使用 Jest 和 vue/test-utils 测试输入文件

    我想使用 Jest 和 vue test utils 测试文件上传器组件 我有这个 describe show progress bar of uploading file gt const wrapper mount FileUpload
  • 如何在 lldb 中创建和使用临时 NSRange?

    NSRange 只是一个 C 结构体 我想在 Xcode 的 lldb 中的断点处创建一个临时的 专门用于 NSArray 方法objectAtIndex inRange 这是行不通的 lldb expr NSRange tmpRange
  • 高效更新 Bokeh 中的图像图以实现交互式可视化

    我正在尝试使用 Bokeh 创建多维数组的不同切片的平滑交互式可视化 切片中的数据根据 用户交互而变化 因此每秒必须更新几次 我编写了一个 Bokeh 应用程序 其中包含几个小图像图 64x64 值 来显示切片的内容 以及在用户与应用程序交
  • 根据年份合并 data.frames 并填写缺失值

    我有两个 data frames 我想将它们合并在一起 第一个是 datess lt seq as Date 2005 01 01 as Date 2009 12 31 days sample lt data frame matrix nc
  • JavaFX:如何在不关注主窗口的情况下关闭子窗口

    我试图在一定时间后以编程方式关闭子窗口 这个子窗口的initOwner是与主舞台一起设置的 但是关闭这个子窗口后 主窗口就会获得焦点 有什么方法可以在不关注主窗口的情况下关闭子窗口 以编程方式 下面是我的问题的快速演示 我尝试了所有可能的方
  • OSX 上“没有名为 _scproxy 的模块”

    我使用的是预装 python 2 6 的 OSX 10 6 并且想通过 easy install 或 setup py 在下载的包中 安装 python 包 就我而言 我正在尝试安装 MySQLdb 在这两种情况下 我都会得到一个堆栈跟踪
  • 在处理其他事情时如何将一堆未提交的更改放在一边

    如果我有一堆未提交的更改 并且想在处理其他事情时将其放在一边 然后稍后 例如几天后 返回并继续工作 完成此任务最简单的工作流程是什么 到目前为止我只体验过 Mercurial 的基本功能 我通常的方法是使用克隆创建一个新分支 但可能有更好的
  • 自动布局问题 Xcode 8 [_SwiftValue nsli_superitem]

    将我的代码转换为 Swift 3 我发现了一个奇怪的问题 现在 2016 年 9 月 15 日 Xcode 8 公共版本 已经发布 转换代码后 我的应用程序崩溃了 没有明显的原因 自动布局有问题 日志显示如下 SwiftValue nsli
  • Android,在库项目中提供应用程序特定常量的最佳方式?

    我正在为许多 Android 应用程序创建一个库项目 这些应用程序都具有一些我希望包含在库项目中的通用功能 但库项目功能需要使用特定于应用程序的常量 所以我正在寻找一种方法来为库函数提供常量名称并允许每个应用程序定义它们 特定应用程序常量的
  • data.table 中的 Between 与 inrange

    In R s data table 什么时候应该选择 between and inrange 用于子集化操作 我已阅读帮助页面 between我仍然对这些差异感到摸不着头脑 library data table X data table a
  • 屏幕阅读器何时应该可以使用“隐藏”元素(为了可访问性,a11y)?

    我听到建议 hidden类不作为 hidden display none 但将其宽度和高度设置为 1 并使用剪切等 使元素看起来仍然存在于屏幕上 但内容不可见 但是 当我们使用 JavaScript 隐藏某些内容时 该元素的目的就已经完成
  • 查找 JUnit TestCase 中测试方法的数量

    有没有办法知道测试用例中测试方法的数量 我想做的是有一个测试用例来测试几种场景 对于所有这些我只会执行一次 data setUp 同样 我想在所有测试方法结束时进行一次清理 tearDown 我当前使用的方法是维护一个计数器来记录文件中存在
  • Angular 2 - 如何在组件中包含 javascript?

    我对 Angular 完全陌生 直接开始使用 Angular 2 Angular 的一大优点是我可以模块化网页的每个功能 组件有自己的 html 和样式表 但是他们自己的 javascript 文件呢 我怎样才能包含它自己的特定 javas
  • 使用 powershell 获取家庭网络上的设备名称及其 IP 地址

    这个问题是由于尝试管理我的家庭 WiFi 网络而产生的 我一直在尝试 get netipaddress ipconfig 和 nslookup exe 等命令 以下命令有点引导我到某个地方 但它没有我正在寻找的信息 Get NetIPAdd
  • Zend_Form_Element_MultiCheckbox:如何将一长串复选框显示为列?

    所以我正在使用Zend Form Element MultiCheckbox显示一长串复选框 如果我简单地echo元素 我得到很多由分隔的复选框 br 标签 我想找出一种方法来利用简单性Zend Form Element MultiChec
  • 允许 Django Rest Framework 序列化器字段名称中使用连字符

    鉴于我正在编写代码的 OpenAPI 规范需要在请求正文中使用连字符大小写 又名短横线大小写 变量名称 那么在使用 Django Rest Framework 时应如何处理 例如 一个请求POST thing创建一个具有以下主体的事物 ow
  • node.js - 将数据推送到客户端 - 只能连接一个客户端?

    我正在尝试创建一个服务器端解决方案 通过 node js 定期将数据推送到客户端 无客户端轮询 连接应该永久打开 每当服务器有新数据时 它就会将其推送到客户端 这是我的简单示例脚本 var sys require sys http requ
  • 为什么测试时一定要用DataParallel?

    在GPU上训练 num gpus设置为1 device ids list range num gpus model NestedUNet opt num channel 2 to device model nn DataParallel m