使用libtorch调用EfficientNet模型(pt文件)

2023-11-13

1.首先确定自己电脑上的pytorch版本,然后下载合适的libtorch版本。
使用libtorch调用c++接口,要保证下载的libtorch的版本和pytorch的版本对应

至少使用低版本的pytorch和高版本的libtorch是没法成功的。反过来是可以的:即高版本的pytorch和低版本的libtorch。

各个版本的libtorch下载参考地址

2.训练
训练的时候推荐只保存模型参数,我之前是保存整个模型但是一值转换不成功。
训练完后会保存一个xxx.pth文件,然后就可以转换这个文件了。

3.转换
需要将训练好的xxx.pth文件转换成xxx.pt文件。直接贴代码(每个人的需求不一样,根据自己的需求改)

from PIL import Image
import torch
from torchvision import transforms,models
import torch.nn as nn
from efficientnet_pytorch import EfficientNet
from torch.autograd import Variable

def Transfer_cup_model():
 	#这个代码使用的前提是需要在你的虚拟环境中下载好EfficientNet的包 ( pip install efficientnet_pytorch)
    model = EfficientNet.from_name('efficientnet-b5')
    model.set_swish(memory_efficient=False)
    num_ftrs = model._fc.in_features
    model._fc = nn.Linear(num_ftrs, 6)   #我训练的是6类,对应着自己的类别数改
    # model.load_state_dict(torch.load('Garbage/each_model/epoch_29.pth'))
   
    model.load_state_dict(
        {k.replace('module.', ''): v for k, v in torch.load('Garbage/each_model/epoch_29.pth').items()})

    model.eval()
    model.cpu()
    example = torch.randn(1, 3, 456, 456)
    with torch.no_grad():
        traced_script_module = torch.jit.trace(model, example)
        traced_script_module.save('model-cpu.pt')

def Transfer_GPU_model():
    model = EfficientNet.from_name('efficientnet-b5')
    model.set_swish(memory_efficient=False)
    num_ftrs = model._fc.in_features
    model._fc = nn.Linear(num_ftrs, 6)
    # model.load_state_dict(torch.load('Garbage/each_model/epoch_29.pth'))
    #我是用四个GPU并行训练的,需要加这一句,如果是单GPU可以用上面的一句
    model.load_state_dict(
        {k.replace('module.', ''): v for k, v in torch.load('Garbage/each_model/epoch_29.pth').items()})

    model.eval()
    # model = torch.load('Garbage/epoch_29.pth')
    # model = torch.nn.DataParallel(model).cuda()
    if torch.cuda.is_available():
        model.cuda()

    # 转libtorch
    example = torch.randn(1, 3, 456, 456).cuda()
    with torch.no_grad():
        traced_script_module = torch.jit.trace(model, example)
        traced_script_module.save('model-gpu.pt')

def Transfer_onnx_model():
    model = EfficientNet.from_name('efficientnet-b5')
    model.set_swish(memory_efficient=False)
    num_ftrs = model._fc.in_features
    model._fc = nn.Linear(num_ftrs, 6)
    # model.load_state_dict(torch.load('Garbage/each_model/epoch_29.pth'))
    model.load_state_dict(
        {k.replace('module.', ''): v for k, v in torch.load('Garbage/each_model/epoch_29.pth').items()})

    model.eval()
    # model = torch.load('Garbage/epoch_29.pth')
    # model = torch.nn.DataParallel(model).cuda()
    if torch.cuda.is_available():
        model.cuda()

    #转onnx
    dummy_input1 = torch.randn(1, 3, 456, 456).cuda()
    torch.onnx.export(model,dummy_input1,"Efficient.onnx",export_params=True, verbose=True, training=False)


def Transfer_type(Type):
    if Type == 0:
        Transfer_cup_model()   #cpu模型
    elif Type == 1:
        Transfer_GPU_model()   #Gpu模型
    elif Type == 2:
        Transfer_onnx_model()  #onnx模型(GPU)
    else:
        print("The model of this type cann't transfer")

def main():
    Transfer_type(1)

if __name__ == '__main__':
    main()

4.调用
网上由很多教程,写的很详细 参考
其实就是调用下载好的libtorch库,跟调用opencv库一样,VS中包含目录、库目录、依赖项 添加好就行(不清楚添加libtorch的哪些包都,都添加进去也行)。
数据预处理和模型使用可以参考这篇博客: 数据预处理及模型使用

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用libtorch调用EfficientNet模型(pt文件) 的相关文章

随机推荐

  • Vue+elementUI el-input输入框手机号校验

    1 限制input框内只能输入数字 且为11位 type number 数字类型 maxlength属性对type number 类型的输入框无效 ninput if value length gt 11 value value slice
  • 达梦数据库教程:docker安装DM8数据库

    安装前准备 软硬件 版本 终端 X86 64 架构 Docker 2023 年 6 月版 下载 Docker 安装包 请在达梦数据库官网下载 Docker 安装包 导入安装包 拷贝安装包到 opt 目录下 执行以下命令导入安装包 docke
  • windows下nginx的安装及使用

    1 下载nginx http nginx org en download html 下载稳定版本 以nginx Windows 1 12 2为例 直接下载 nginx 1 12 2 zip 下载后解压 解压后如下 2 启动nginx 有很多
  • 为什么寄存器比内存快?

    原文出处 Mike Ash 译文出处 阮一峰 计算机的存储层次 memory hierarchy 之中 寄存器 register 最快 内存其次 最慢的是硬盘 同样都是晶体管存储设备 为什么寄存器比内存快呢 Mike Ash写了一篇很好的解
  • Vue使用routerlink实现点击导航栏进行页面跳转

    实现内容 如图所示 要实现的是 点击导航栏中的Data Set Data Mining Result List Model List区域跳转至对应界面 使用router link来实现跳转 1 如代码所示 router link后面的to需
  • linux笔记--文件内容操作和历史命令

    目录 cat命令 more命令 less命令 head命令 tail命令 sed命令 vim编辑器 history命令 clear命令 cat命令 查看文件内容 标准输出 补充 1 标准输出 在linux中规定为输出到屏幕 2 标准输入 在
  • 如何用cin读入空格

    在我们使用cin读入字符时 默认是跳过中间的空格以及可能的制表符和换行符 那么 如何让其不跳过空格呢 我们可以使用操作符noskipws来实现 cin gt gt noskipws 设置cin读取空白符 char ch while cin
  • 激光SLAM7-基于已知位姿的构图算法

    1 通过覆盖栅格建图算法进行栅格地图的构建 1 1 Theory 1 2 code 这里没有判断idx和hitPtIndex是否有效 start of TODO 对对应的map的cell信息进行更新 1 2 3题内容 GridIndex h
  • 服务器的相关知识

    服务器的分类 服务器指一个管理资源并为用户提供服务的计算机 通常分为文件服务器 数据库服务器和应用程序服务器 对于普通PC来说 服务器在稳定性 安全性 性能等方面都要求更高 因此CPU 芯片组 内存 磁盘系统 网络等硬件和普通PC有所不同
  • 主动配电网SOCP_OPF学习笔记(4)配电网重构

    配电网中的开关一般可分为联络开关和分段开关 联络开关负责转供备用和网络结构优化 常开 分段开关用于连接两条线路段的开关 为常闭 通过改变这两种开关的状态来调整网络拓扑结构 称为网络重构 加入联络开关支路会形成弱环网 1 辐射状拓扑约束 为了
  • 开源的推荐系统简介TOP 10

    最近这两年推荐系统特别火 本文搜集整理了一些比较好的开源推荐系统 即有轻量级的适用于做研究的SVDFeature LibMF LibFM等 也有重 量级的适用于工业系统的 Mahout Oryx EasyRecd等 供大家参考 PS 这里的
  • 利用chatgpt快速初步学习pandas

    最近体验了chatgpt作为编程助手的功能 确实很厉害 只要你擅长提问 找答案很精准快捷 由此可以想到是否能够通过系列提问 快速上手一个工具 以pandas为例 开始提问学习 是什么 有什么用 我需要用python处理表格数据 给我推荐现在
  • tp1900芯片对比7621a_貌似很多人看不起MTK,其实MTK7621A已经很给力了。

    以下内容为转载 和大家一起学习一下MTK7621A相关知识 全球无线通讯及数字多媒体IC设计领导厂商联发科技股份有限公司 MediaTek Inc 2013年11月宣告推出面向802 11ac高端路由器的全新双核网络芯片MT7621A MT
  • Python中将图片用base64进行编码

    我们可以使用base64模块 通过base64 b64encode 函数将图片直接转换为base64编码 import base64 假设a目录下有123 jpg图片 with open a 123 jpg rb as f read f r
  • 面试指南之如何介绍做过的项目

    面试是每个程序员都逃不过的一环 在我面试过的程序员中 有一半的程序员都描述不好自己做过的项目 有些都讲不到3分钟就结束了 听完我都不知道这个项目是做什么的 所以 决定写下这遍手记 希望对正在找工作的你有所帮助 在面试过程中 程序员都需要介绍
  • java人脸识别_使用百度智能云的人工智能模块,让你的Java应用更加智能

    人工智能 前言 之前有在微头条简单介绍了一下过程 想了一下 还是觉得给详细分享一下干货才行 于是才有了这篇文章 百度智能云 百度智能云是百度即All in AI主题之后开发出来的产品 总体看来可以分为两部分 第一部分是百度机器学习BML 是
  • arch linux使用iptables

    一 安装 arch中已经编译安装了iptables 无需重新安装 二 启动 iptables启动时 会读取 etc iptables iptables rules中写的规则 而Arch默认不启动iptables服务 也不会创建这个文件 这个
  • Windows环境下,使用GnuWin32工具安装后缀为patch的补丁到C源码软件包

    在CMD命令行 cd到GnuWin32安装目录的bin下 在命令行输入 patch exe d WORK DIR i PATCH FILE p 0 l N WORK DIR 要打补丁的目录 PATCH FILE 补丁文件 p 0 直接使用补
  • 黑马程序员———类加载器

    Java培训 Android培训 iOS培训 Net培训 期待与您交流 Java虚拟机中可以安装多个类加载器 系统默认三个 主要类加载器 每个类负责加载特定位置的类 BootStrap ExtClassLoader AppClassLoad
  • 使用libtorch调用EfficientNet模型(pt文件)

    1 首先确定自己电脑上的pytorch版本 然后下载合适的libtorch版本 使用libtorch调用c 接口 要保证下载的libtorch的版本和pytorch的版本对应 至少使用低版本的pytorch和高版本的libtorch是没法成