[PyTorch | Bug] Fixing RuntimeError: Error(s) in loading state_dict for Network: size mismatch

2023-05-16

Contents

  • Background
  • Solution

Background

Open-source project on GitHub: https://github.com/zhang-tao-whu/e2ec

I ran fine-tuning from the COCO pretrained checkpoint with:

python train_net.py coco_finetune --bs 12 \
--type finetune --checkpoint data/model/model_coco.pth

This fails with the following error:

loading annotations into memory...
Done (t=0.09s)
creating index...
index created!
load model: data/model/model_coco.pth
Traceback (most recent call last):
  File "train_net.py", line 67, in <module>
    main()
  File "train_net.py", line 64, in main
    train(network, cfg)
  File "train_net.py", line 40, in train
    begin_epoch = load_network(network, model_dir=args.checkpoint, strict=False)
  File "/root/autodl-tmp/e2ec/train/model_utils/utils.py", line 66, in load_network
    net.load_state_dict(net_weight, strict=strict)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Network:
        size mismatch for dla.ct_hm.2.weight: copying a param with shape torch.Size([80, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 256, 1, 1]).
        size mismatch for dla.ct_hm.2.bias: copying a param with shape torch.Size([80]) from checkpoint, the shape in current model is torch.Size([1]).

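My own dataset has only 1 class, whereas COCO has 80. The dla.ct_hm.2 parameters in the pretrained model (weight of shape [80, 256, 1, 1], bias of shape [80]) therefore do not match the shapes in my model ([1, 256, 1, 1] and [1]), so the pretrained weights for this layer have to be discarded.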
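Before editing the loader, it can help to list every parameter whose shape differs between the checkpoint and the current model. The following is a minimal sketch, not part of e2ec: the helper name find_shape_mismatches is hypothetical, and the stand-in objects only carry a .shape attribute so that the example runs without the real checkpoint. The two dla.ct_hm.2 shapes are taken from the error message above; the dla.ct_hm.0.bias entry is an assumed key used purely for illustration.

```python
from types import SimpleNamespace


def find_shape_mismatches(model_state, checkpoint_state):
    """Return {name: (checkpoint_shape, model_shape)} for parameters that
    exist in both state dicts but differ in shape."""
    mismatches = {}
    for name, param in model_state.items():
        if name not in checkpoint_state:
            continue  # missing keys are a separate issue from size mismatches
        ckpt_shape = tuple(checkpoint_state[name].shape)
        model_shape = tuple(param.shape)
        if ckpt_shape != model_shape:
            mismatches[name] = (ckpt_shape, model_shape)
    return mismatches


# Stand-ins for tensors: only the .shape attribute matters here.
checkpoint_state = {
    "dla.ct_hm.0.bias": SimpleNamespace(shape=(256,)),  # assumed key, illustration only
    "dla.ct_hm.2.weight": SimpleNamespace(shape=(80, 256, 1, 1)),
    "dla.ct_hm.2.bias": SimpleNamespace(shape=(80,)),
}
model_state = {
    "dla.ct_hm.0.bias": SimpleNamespace(shape=(256,)),  # assumed key, illustration only
    "dla.ct_hm.2.weight": SimpleNamespace(shape=(1, 256, 1, 1)),
    "dla.ct_hm.2.bias": SimpleNamespace(shape=(1,)),
}

mismatches = find_shape_mismatches(model_state, checkpoint_state)
for name, (ckpt_shape, model_shape) in mismatches.items():
    print(f"{name}: checkpoint {ckpt_shape} vs model {model_shape}")
```

With real objects, the same helper should accept net.state_dict() and the 'net' entry of the loaded checkpoint, since tensors expose .shape as well.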

Solution

Modify load_network in e2ec/train/model_utils/utils.py as follows:

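import os

import torch
from termcolor import colored  # assumed source of colored(); not shown in the original snippet


def load_network(net, model_dir, strict=True, map_location=None):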

    if not os.path.exists(model_dir):
        print(colored('WARNING: NO MODEL LOADED !!!', 'red'))
        return 0

    print('load model: {}'.format(model_dir))
    if map_location is None:
        pretrained_model = torch.load(model_dir, map_location={'cuda:0': 'cpu', 'cuda:1': 'cpu',
                                                               'cuda:2': 'cpu', 'cuda:3': 'cpu'})
    else:
        pretrained_model = torch.load(model_dir, map_location=map_location)
    if 'epoch' in pretrained_model.keys():
        epoch = pretrained_model['epoch'] + 1
    else:
        epoch = 0
    pretrained_model = pretrained_model['net']

    net_weight = net.state_dict()
    for key in net_weight.keys():
        net_weight.update({key: pretrained_model[key]})
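    # Discard the class-dependent head weights (80 COCO classes vs. 1 here);
    # the model keeps its own initialization for these two parameters.
    net_weight.pop("dla.ct_hm.2.weight")
    net_weight.pop("dla.ct_hm.2.bias")

    # The popped keys are now missing from net_weight, so this call only
    # succeeds when the caller passes strict=False.
    net.load_state_dict(net_weight, strict=strict)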
    return epoch

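Note: setting strict=False in load_state_dict only tolerates missing or unexpected keys, that is, layers that were added or removed. It does not help when a parameter with the same name has a different shape; that case still raises the size mismatch error, as the traceback above shows (load_network was called with strict=False).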
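The hard-coded pop calls above work for this one checkpoint. A more general variant, sketched here under assumptions, keeps only the checkpoint tensors whose name and shape both match the current model. The helper name load_matching_weights and the two toy nn.Sequential models are hypothetical stand-ins for the 80-class checkpoint and the 1-class network; they are not e2ec code.

```python
import torch
from torch import nn


def load_matching_weights(net, checkpoint_state):
    """Load only the checkpoint tensors whose name and shape match `net`.

    Returns the sorted list of checkpoint keys that were skipped.
    """
    own_state = net.state_dict()
    compatible = {
        name: tensor
        for name, tensor in checkpoint_state.items()
        if name in own_state and tensor.shape == own_state[name].shape
    }
    skipped = sorted(set(checkpoint_state) - set(compatible))
    # strict=False is still required: the skipped keys now count as missing.
    net.load_state_dict(compatible, strict=False)
    return skipped


# Toy stand-ins: an 80-class "pretrained" head and a 1-class fine-tuning head.
pretrained = nn.Sequential(nn.Conv2d(4, 8, 3), nn.ReLU(), nn.Conv2d(8, 80, 1))
finetune = nn.Sequential(nn.Conv2d(4, 8, 3), nn.ReLU(), nn.Conv2d(8, 1, 1))

# strict=False alone does not get past a shape mismatch.
try:
    finetune.load_state_dict(pretrained.state_dict(), strict=False)
    strict_false_failed = False
except RuntimeError:
    strict_false_failed = True

skipped = load_matching_weights(finetune, pretrained.state_dict())
print("strict=False raised:", strict_false_failed)
print("skipped keys:", skipped)
```

The skipped layers keep their fresh initialization, which is usually what is wanted for a classification head that has to be retrained on a different number of classes.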
