由于需要移植模型到比特大陆,华为昇腾这些平台。他们基本都支持caffe的模型,对其他模型支持不太好。用其他方法pytorch转caffe不然就是绕道太多,不然就是很多坑。这里记录一个最简单的方法:
[作者环境: torch 1.2.0 torchvision 0.4.0 ]
pip install pytorch2caffe
import torch
import torchvision
from pytorch2caffe import pytorch2caffe
def SaveDemo():
from torchvision.models import resnet
name = 'resnet18'
resnet18 = resnet.resnet18()
resnet18.eval()
dummy_input = torch.ones([1, 3, 224, 224])
pytorch2caffe.trans_net(resnet18, dummy_input, name)
pytorch2caffe.save_prototxt('{}.prototxt'.format(name))
pytorch2caffe.save_caffemodel('{}.caffemodel'.format(name))
if __name__ == '__main__':
SaveDemo()
如果你的模型中使用了avg_pool 使用这种写法:
x = F.avg_pool2d(x,7)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)