1. Default storage path for pretrained models loaded with PyTorch
Local machine (CPU): C:/Users/asus/.cache/torch/hub/checkpoints
Server (GPU): /home/xxx/.cache/torch/hub/checkpoints
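These paths are not hard-coded. According to the torch.hub documentation, the hub directory is chosen in this order: a path set with torch.hub.set_dir(), then $TORCH_HOME/hub, then $XDG_CACHE_HOME/torch/hub, and finally ~/.cache/torch/hub. Checkpoints go into its checkpoints subfolder. The minimal sketch below re-implements that documented lookup in plain Python so you can predict where weights will land on a given machine. `checkpoint_dir` is a hypothetical helper, not a PyTorch function, and it ignores set_dir(); inside a real session, torch.hub.get_dir() is the authoritative answer.

```python
import os


def checkpoint_dir(env=None):
    """Hypothetical helper: mirror torch.hub's documented default lookup.

    Ignores torch.hub.set_dir(). Pass a dict as `env` to simulate other
    environment variables; by default os.environ is used.
    """
    env = os.environ if env is None else env
    if "TORCH_HOME" in env:
        torch_home = env["TORCH_HOME"]
    else:
        # XDG_CACHE_HOME falls back to ~/.cache, as in the torch.hub docs
        cache_home = env.get("XDG_CACHE_HOME", os.path.join("~", ".cache"))
        torch_home = os.path.join(cache_home, "torch")
    hub_dir = os.path.join(os.path.expanduser(torch_home), "hub")
    return os.path.join(hub_dir, "checkpoints")


print(checkpoint_dir())
```

Setting TORCH_HOME before starting Python is therefore one way to move the cache without touching any library code.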
2. Download URLs for commonly used pretrained models
https://github.com/pytorch/vision/tree/master/torchvision/models
ResNet
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
VGG
model_urls = {
'vgg11': 'https://download.pytorch.org/models/vgg11-8a719046.pth',
'vgg13': 'https://download.pytorch.org/models/vgg13-19584684.pth',
'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}
DenseNet
model_urls = {
'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}
GoogLeNet
model_urls = {
# GoogLeNet ported from TensorFlow
'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth',
}
Inception
model_urls = {
# Inception v3 ported from TensorFlow
'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth',
}
...and so on for the other architectures.
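Every URL ends in `<name>-<hash>.pth`. load_state_dict_from_url saves the download inside model_dir under the URL's base file name, and with check_hash=True it checks that the hex digits after the hyphen match the start of the file's SHA-256 digest. So if you download a file by hand, keep its original name. The sketch below illustrates that convention; `cached_path`, `hash_prefix` and `verify_file` are hypothetical helpers, and the regular expression is assumed to be equivalent to the one torch.hub uses.

```python
import hashlib
import os
import re
from urllib.parse import urlparse

# <name>-<hex digits>.<ext>; the hex part is a prefix of the SHA-256 digest
HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")


def cached_path(url, model_dir):
    """Hypothetical helper: where a URL's file ends up inside model_dir."""
    filename = os.path.basename(urlparse(url).path)
    return os.path.join(model_dir, filename)


def hash_prefix(url):
    """Hypothetical helper: extract the hash prefix from the file name."""
    filename = os.path.basename(urlparse(url).path)
    match = HASH_REGEX.search(filename)
    return match.group(1) if match else None


def verify_file(path, prefix):
    """Return True if the file's SHA-256 hex digest starts with prefix."""
    sha256 = hashlib.sha256()
    with open(path, "rb") as f:
        # read in 1 MiB chunks so large checkpoints do not fill memory
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha256.update(chunk)
    return sha256.hexdigest().startswith(prefix)


url = "https://download.pytorch.org/models/resnet50-0676ba61.pth"
print(cached_path(url, "/home/xxx/.cache/torch/hub/checkpoints"))
print(hash_prefix(url))  # 0676ba61
```

After a manual download, verify_file(cached_path(url, model_dir), hash_prefix(url)) is a quick sanity check that the file is complete.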
3. Changing the download path of pretrained models
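import torch.utils.model_zoo as model_zoo
from torchvision.models.resnet import ResNet, Bottleneck

def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        # download the weights (or read them from the cache), then load them
        pretrained_model = model_zoo.load_url(model_urls['resnet50'])
        model.load_state_dict(pretrained_model)
    return model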
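Put the cursor on load_url and Ctrl + left-click it to jump to its definition. In the same way you can follow the network class your code calls to see how that network loads its weights. In torch.utils.model_zoo, load_url is simply an alias of torch.hub.load_state_dict_from_url: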
def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None):
    r"""Loads the Torch serialized object at the given URL.
    If downloaded file is a zip file, it will be automatically
    decompressed.
    If the object is already present in `model_dir`, it's deserialized and
    returned.
    The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
    `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.
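model_dir is the directory in which the model is stored. Its default value is None, which falls back to <hub_dir>/checkpoints. To store models elsewhere, change model_dir to your own path.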
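You can also download the pretrained models into that folder in advance, which avoids slow or failed downloads on a poor network connection.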
def load_state_dict_from_url(url, model_dir='/d3/xxx/checkpoint', map_location=None, progress=True, check_hash=False, file_name=None):
    r"""Loads the Torch serialized object at the given URL.
    If downloaded file is a zip file, it will be automatically
    decompressed.
    If the object is already present in `model_dir`, it's deserialized and
    returned.
    The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
    `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.
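Editing the default in the installed torch source works, but the change is lost whenever PyTorch is reinstalled or upgraded. Because model_dir is an ordinary keyword argument, it can also be passed at the call site, e.g. torch.hub.load_state_dict_from_url(url, model_dir='/d3/xxx/checkpoint'). For downloading the weights in advance, here is a minimal standard-library sketch that fetches a URL into a folder only if the file is not already there, keeping the original file name because load_state_dict_from_url looks for exactly that name inside model_dir. `predownload` is a hypothetical helper and it does not verify hashes.

```python
import os
import urllib.request
from urllib.parse import urlparse


def predownload(url, model_dir):
    """Hypothetical helper: download url into model_dir unless already cached.

    Keeps the URL's file name so that load_state_dict_from_url will find
    the file in model_dir instead of downloading it again.
    """
    os.makedirs(model_dir, exist_ok=True)
    filename = os.path.basename(urlparse(url).path)
    dst = os.path.join(model_dir, filename)
    if not os.path.exists(dst):
        tmp = dst + ".partial"
        urllib.request.urlretrieve(url, tmp)  # network access happens only here
        os.replace(tmp, dst)  # expose the file only once it is fully written
    return dst
```

For example, predownload(model_urls['resnet50'], '/d3/xxx/checkpoint') on a machine with a good connection, after which the folder can be copied to the server.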
4. Inspecting the parameters of a pretrained model (ResNet50)
import torch

path = '/home/xxx/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth'
# map_location='cpu' lets the file load on machines without a GPU
pretrained_dict = torch.load(path, map_location='cpu')
for k, v in pretrained_dict.items():
    print(k)
Output:
conv1.weight
bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked
layer1.0.conv1.weight
layer1.0.bn1.weight
layer1.0.bn1.bias
...
...
...
layer4.2.bn3.running_var
layer4.2.bn3.num_batches_tracked
fc.weight
fc.bias
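The keys follow the module hierarchy of the network: layer4.2.bn3, for instance, is the third BatchNorm of the third Bottleneck block in layer4. When the list is long, grouping keys by their first component gives a quicker overview. A small sketch, where `group_by_layer` is a hypothetical helper that accepts any iterable of key strings such as pretrained_dict.keys(); the sample list below only reuses a few keys from the output above.

```python
def group_by_layer(keys):
    """Hypothetical helper: group state_dict keys by their top-level module."""
    groups = {}  # dicts keep insertion order, so layers stay in network order
    for key in keys:
        top = key.split(".")[0]
        groups.setdefault(top, []).append(key)
    return groups


# A few of the keys printed above, standing in for pretrained_dict.keys()
sample_keys = [
    "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean",
    "bn1.running_var", "bn1.num_batches_tracked",
    "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias",
    "fc.weight", "fc.bias",
]
for layer, names in group_by_layer(sample_keys).items():
    print(f"{layer}: {len(names)} tensors")
```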
5. Loading all parameters except the fully connected layer
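if pretrained:
    pretrained_model = model_zoo.load_url(model_urls['resnet50'])
    state = model.state_dict()
    for key in state.keys():
        # copy every pretrained tensor except those of the fully connected (fc) layer
        if key in pretrained_model.keys() and "fc" not in key:
            state[key] = pretrained_model[key]
    # load once, after all tensors have been copied
    model.load_state_dict(state)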
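Dropping fc is typically needed when the new model is built with a different num_classes, because the ImageNet fc layer has 1000 outputs and its weights would not fit. Two details may be worth noting: the substring test "fc" not in key also skips any key that merely contains "fc" (harmless for ResNet, whose only such keys are fc.weight and fc.bias, but risky for other architectures), and load_state_dict accepts strict=False, which tolerates missing keys. A minimal sketch of the same filtering that matches the module prefix exactly; `drop_prefixes` is a hypothetical helper, and the key layer1.0.se.fc1.weight is an invented example of a non-fc key containing "fc".

```python
def drop_prefixes(state_dict, prefixes=("fc.",)):
    """Hypothetical helper: copy state_dict without keys starting with any
    of the given module prefixes (default: the final fc layer)."""
    return {k: v for k, v in state_dict.items() if not k.startswith(prefixes)}


# Stand-in values; with real weights these would be tensors.
pretrained = {"conv1.weight": 1, "layer1.0.se.fc1.weight": 2,
              "fc.weight": 3, "fc.bias": 4}
filtered = drop_prefixes(pretrained)
print(sorted(filtered))
# With PyTorch, the filtered dict could then be loaded with
#   model.load_state_dict(filtered, strict=False)
# which leaves fc at its initial values and lists it under missing_keys.
```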
To be continued…