nn.DataParallel
包装模型,其中实际模型被分配给module
属性。这也意味着状态字典中的键有一个module.
prefix.
让我们看一个非常简化的版本,只有一个卷积来看看差异:
class NestedUNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
model = NestedUNet()
model.state_dict().keys() # => odict_keys(['conv1.weight', 'conv1.bias'])
# Wrap the model in DataParallel
model_dp = nn.DataParallel(model, device_ids=range(num_gpus))
model_dp.state_dict().keys() # => odict_keys(['module.conv1.weight', 'module.conv1.bias'])
你保存的状态字典nn.DataParallel
与常规模型的状态不一致。您要将当前状态字典与加载状态字典合并,这意味着加载状态将被忽略,因为模型没有任何属于键的属性,而是留下随机初始化的模型。
为了避免犯这种错误,您不应该合并状态字典,而应该直接将其应用到模型,在这种情况下,如果键不匹配就会出现错误。
RuntimeError: Error(s) in loading state_dict for NestedUNet:
Missing key(s) in state_dict: "conv1.weight", "conv1.bias".
Unexpected key(s) in state_dict: "module.conv1.weight", "module.conv1.bias".
为了使您保存的状态字典兼容,您可以去掉module.
prefix:
pretrained_dict = {key.replace("module.", ""): value for key, value in pretrained_dict.items()}
model.load_state_dict(pretrained_dict)
您还可以通过从以下位置解开模型来避免将来出现此问题nn.DataParallel
在保存其状态之前,即保存model.module.state_dict()
。因此,您始终可以先加载模型的状态,然后再决定将其放入nn.DataParallel
如果您想使用多个 GPU。