pytorch加载与修改预训练模型
有时会希望用预训练的模型来fine-tune或是作为初始化(毕竟初始化权重真的玄学…),但是不需要其中某一些层,这时候就需要对加载的预训练模型做一些修改。
如果已经知道了模型的结构,这件事还是比较容易的,不知道的话我就不会了。
import torch
import torch.nn as nn
import torchvision
from torchstat import stat
class resnet18_new(nn.Module):
def __init__(self):
super(resnet18_new,self).__init__()
pretrained_net=torchvision.models.resnet18()
pretrained_layers=pretrained_net.children()
layers=list(pretrained_layers)[:-1]
self.net=nn.Sequential(*layers)
self.fc=nn.Linear(512,10)
def forward(self,x):
x=self.net(x)
x=x.view(x.shape[0],-1)
x=self.fc(x)
return x
if __name__=='__main__':
model=resnet18_new()
stat(model,(3,32,32))
首先我先用.children()获取了一下模型中所有的layer,这里说明一下.children()与.modules()的区别:
.children()只会按顺序获取__init()__里的所有layer,而.module()不但会获取__init()__里的所有layer,还会把nn.sequential里的所有layer也拆开获取,直到不能再拆为止。所以说定义时最好按照网络连接的顺序来…
获取了所有的layer之后,就可以把转成list,选择自己想要的层了,不用担心权重或者连接的问题,layer内部的结构和权重都能被保留,但是写在forward中的结构可能就无法保留了,尽量写在__init()__会容易再利用一些(这个事情后面我再验证吧,因为转成了list,我觉得应该就留不下来了)。
如果很清楚模型的结构,其实也可以用.modules()把想要的层的权重取出来(.state_dict()),然后再放到自己的模型里(.load_state_dict()),就是比较麻烦,但是结构肯定是对的。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)