Pytorch转Caffe最简单方法

2023-05-16

 

由于需要移植模型到比特大陆,华为昇腾这些平台。他们基本都支持caffe的模型,对其他模型支持不太好。用其他方法pytorch转caffe不然就是绕道太多,不然就是很多坑。这里记录一个最简单的方法:

[作者环境: torch 1.2.0 torchvision 0.4.0 ]

pip install pytorch2caffe

import torch
import torchvision
from pytorch2caffe import pytorch2caffe
def SaveDemo():
    from torchvision.models import resnet

    name = 'resnet18'
    resnet18 = resnet.resnet18()
    resnet18.eval()
    dummy_input = torch.ones([1, 3, 224, 224])
    pytorch2caffe.trans_net(resnet18, dummy_input, name)
    pytorch2caffe.save_prototxt('{}.prototxt'.format(name))
    pytorch2caffe.save_caffemodel('{}.caffemodel'.format(name))


if __name__ == '__main__':
    SaveDemo()

如果你的模型中使用了avg_pool 使用这种写法:

x = F.avg_pool2d(x,7)

 

 

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch转Caffe最简单方法 的相关文章

随机推荐