Pytorch替换model对象任意层的方法 - 知乎直接上代码 import torch from torch import nn from torchvision.models import alexnet # 核心函数,参考了torch.quantization.fuse_modules()的实现 def _set_module(model, submodule_key, module): tokens = …https://zhuanlan.zhihu.com/p/356273702
import torch
from torch import nn
from torchvision.models import alexnet
# 核心函数,参考了torch.quantization.fuse_modules()的实现
def _set_module(model, submodule_key, module):
tokens = submodule_key.split('.')
sub_tokens = tokens[:-1]
cur_mod = model
for s in sub_tokens:
cur_mod = getattr(cur_mod, s)
setattr(cur_mod, tokens[-1], module)
# 以AlexNet为例子
model = alexnet(pretrained=False)
# 打印原模型
print("原模型")
print(model)
# 打印每个层的名字,和当前配置
# 从而知道要改的层的名字
for module in model.named_modules():
print(module)
# 假设要换掉AlexNet前2个卷积层,将通道从64改成128,其余参数不变
# 定义新层
layer0 = nn.Conv2d(3, 128, (11, 11), (4, 4), (2,2))
layer1 = nn.Conv2d(128, 192, (5, 5), (1, 1), (2,2))
# 层的名字从上面19-20行的打印内容知道AlexNet前2个层的名字为 "features.0" 和 "features.3"
_set_module(model, 'features.0', layer0)
_set_module(model, 'features.3', layer1)
# 打印修改后的模型
print("新模型")
print(model)
# 推理试一下
img = torch.rand((1, 3, 224, 224))
model(img)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)