【LibTorch】C++中部署TorchScript模型

2023-10-26

1. LibTorch安装

下载cuda版本为11.3的LibTorch安装包并解压即完成安装:

# If you need cpu support, please replace "cu113" with "cpu" in the URL below.
wget https://download.pytorch.org/libtorch/nightly/cu113/libtorch-shared-with-deps-latest.zip
unzip libtorch-shared-with-deps-latest.zip

如果需要cpu版本的安装包按照上面注释方法替换即可。

2. C++调用PyTorch模型

2.1 Python中保存tensor数据

保存tensor的函数:

def save_tensor(data_tensor: Tensor, name: str):
    print("[python] %s: "%name, data_tensor)
    f = io.BytesIO()
    torch.save(data_tensor, f, _use_new_zipfile_serialization=True)
    with open('/home/chenxin/peanut/DenseTNT/test_cpp/dense_tnt_infer/data/%s.pt'%name, "wb") as out_f:
        # Copy the BytesIO stream to the output file
        out_f.write(f.getbuffer())

2.2 C++中保存tensor数据

将多个tensor保存到一个文件,便于管理。

示例:

std::string save_path = "peanut/DenseTNT/test_cpp/dense_tnt_infer/data_atlas/";
torch::save({vector_data.unsqueeze(0),
			vector_mask.unsqueeze(0),
			traj_polyline_mask.unsqueeze(0).toType(torch::kBool),
			map_polyline_mask.unsqueeze(0).toType(torch::kBool)},
			save_path + "data.pt");

python通过_parameters依次读取c++保存的tensor:

device = torch.device("cuda:0")
save_path = "peanut/DenseTNT/test_cpp/dense_tnt_infer/data_atlas/"
input_data = torch.jit.load(save_path + "data.pt")
vector_data = input_data._parameters['0'].to(device)
vector_mask = input_data._parameters['1'].bool().to(device)
traj_polyline_mask = input_data._parameters['2'].bool().to(device)
map_polyline_mask = input_data._parameters['3'].bool().to(device)

2.3 C++加载tensor并调用模型

首先需要将PyTorch模型转换为C++支持的TorchScript模型,具体步骤可参考这里

C++调用TorchScript模型代码:

#include <torch/torch.h>
#include "torch/script.h"
#include <iostream>
#include <string>
#include <chrono>

std::vector<char> get_the_bytes(std::string filename)
{
  std::ifstream input(filename, std::ios::binary);
  std::vector<char> bytes(
      (std::istreambuf_iterator<char>(input)),
      (std::istreambuf_iterator<char>()));

  input.close();
  return bytes;
}

// 加载tensor数据
torch::Tensor GetTensor(const std::string &path)
{
  std::vector<char> f = get_the_bytes(path);
  torch::IValue x = torch::pickle_load(f);
  torch::Tensor my_tensor = x.toTensor();
  return my_tensor;
}
int main()
{
  torch::Device device(torch::kCPU);
  if (torch::cuda::is_available())
  {
    device = torch::Device(torch::kCUDA, 0);
  }

  // 读取推理用例数据
  torch::Tensor vector_data = GetTensor("test_cpp/dense_tnt_infer/data/vector_data.pt");
  torch::Tensor vector_mask = GetTensor("test_cpp/dense_tnt_infer/data/vector_mask.pt");
  torch::Tensor traj_polyline_mask = GetTensor("test_cpp/dense_tnt_infer/data/traj_polyline_mask.pt");
  torch::Tensor map_polyline_mask = GetTensor("test_cpp/dense_tnt_infer/data/map_polyline_mask.pt");
  torch::Tensor cent_point = GetTensor("test_cpp/dense_tnt_infer/data/cent_point.pt");

  torch::set_num_threads(1);
  std::vector<torch::jit::IValue> torch_inputs;
  torch_inputs.push_back(std::move(vector_data.to(device)));
  torch_inputs.push_back(std::move(vector_mask.to(device)));
  torch_inputs.push_back(std::move(traj_polyline_mask.to(device)));
  torch_inputs.push_back(std::move(map_polyline_mask.to(device)));
  torch_inputs.push_back(std::move(cent_point.to(device)));
  
  // 加载torch script模型
  torch::jit::script::Module torch_script_model = torch::jit::load("models.densetnt.1/model_save/model.16_script.bin", device);

  for (int i = 0; i < 100; ++i)
  {
    auto t1 = std::chrono::high_resolution_clock::now();
    auto torch_output_tuple = torch_script_model.forward(torch_inputs);
    auto t2 = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double, std::milli> ms_double = t2 - t1;
    std::cout << ms_double.count() << "ms\n";
    // std::cout << torch_output_tuple << std::endl;
  }
  return 0;
}

3. 编译执行C++推理用例

3.1 编写CMakeLists

需要安装cuda和cudnn才能进行cmake;

在C++推理代码同级目录下,创建文件 CMakeLists.txt,写入:

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(dense_tnt_infer)  # 调用模型的c++文件名称
set(CMAKE_PREFIX_PATH "/home/chenxin/libtorch")  # 这里填解压libtorch时的路径

find_package(Torch REQUIRED)

add_executable(${PROJECT_NAME} "dense_tnt_infer.cc")  # 调用模型的c++文件名称
target_link_libraries(${PROJECT_NAME} ${TORCH_LIBRARIES})
set_property(TARGET ${PROJECT_NAME} PROPERTY CXX_STANDARD 14)

3.2 编译并执行用例

CMakeLists.txt同级目录下执行命令:

$ mkdir build
$ cd build
$ cmake ..
$ make 
$ ./dense_tnt_infer.cc

输出:

...
...
12.7836ms
12.7527ms
13.0666ms
14.3305ms
13.804ms
14.1567ms
13.3143ms
13.0827ms
13.0853ms
13.5594ms

推理时长还需要通过修改代码再优化。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【LibTorch】C++中部署TorchScript模型 的相关文章

  • 无法在 QGLWidget 中设置所需的 OpenGL 版本

    我正在尝试在 Qt 4 8 2 中使用 QGLWidget 我注意到 QGLWidget 创建的默认上下文不显示 OpenGL 3 1 以上的任何输出 Qt wiki 有一个教程 http qt project org wiki How t
  • 如何使用C从http下载文件?

    最近几天我试图弄清楚如何从 URL 下载文件 这是我对套接字的第一个挑战 我用它来了解协议 所以我想在没有 cURL 库的情况下只用 C 语言来完成它 我搜索了很多 现在我可以打印页面的源代码 但我认为这与文件不同 我不必只将接收到的数据从
  • 与 MinGW 的静态和动态/共享链接

    我想从一个简单的链接用法开始来解释我的问题 假设有一个图书馆z它可以编译为共享库 libz dll D libs z shared libz dll 或静态库 libz a D libs z static libz a 让我想要链接它 然后
  • 如何使用不同的基本路径托管 Blazor WebAssembly 应用程序

    我有一个 Blazor Webassemble NET 托管应用程序 在我们托管它的服务器上 应用程序的基本路径将是mydomain com coolapp 因此 为了尝试让应用程序在服务器上正确呈现 我一直遵循本页 应用程序基本路径 部分
  • copy_from_user() 错误:目标大小太小

    我正在为内核模块编写 ioctl 处理程序 我想从用户空间复制数据 当我编译禁用优化的代码时 O0 gflags 编译器返回以下错误 include linux thread info h 136 17 error call to bad
  • 来自 double 的 static_cast 可以优化分配给 double 吗?

    我偶然发现了一个我认为不必要的功能 并且通常让我感到害怕 float coerceToFloat double x volatile float y static cast
  • Visual Studio 2013 调试器显示 std::string 的奇怪值

    我有一个大型的 cmake 生成的解决方案 其中包含许多项目 由于某种原因 我无法查看字符串的内容 因为根据调试器 Bx Buf含有一些垃圾 text c str 正确返回 Hello 该问题不仅仅发生在本地字符串上 返回的函数std st
  • 如何在 C# 中以编程方式将行添加到 DataGrid?

    正如标题所述 我正在尝试使用 C 以编程方式将行添加到 DataGrid 但我似乎无法使其工作 这是我到目前为止所拥有的 I have a DataGrid declared as dg in the XAML foreach string
  • 用于 C++ 中图像分析的 OpenCV 二进制图像掩模

    我正在尝试分析一些图像 这些图像的外部周围有很多噪声 但内部有一个清晰的圆形中心 中心是我感兴趣的部分 但外部噪声正在影响我对图像的二进制阈值处理 为了忽略噪音 我尝试设置一个已知中心位置和半径的圆形蒙版 从而使该圆之外的所有像素都更改为黑
  • 从图像创建半透明光标

    是否可以从图像创建光标并使其半透明 我目前正在拍摄自定义图像并覆盖鼠标光标图像 如果我可以将其设为半透明 那就太好了 但不是必需的 销售人员喜欢闪亮的 目前正在做这样的事情 Image cursorImage customImage Get
  • DateTime.ParseExact - 为什么 yy 变成 2015 而不是 1915

    为什么 NET 假定以下年份是 2015 年 而不是 1915 年 var d DateTime ParseExact 20 11 15 dd MM yy new CultureInfo en GB 我想 它会尝试接近 但其背后是否有合理的
  • 使用任一默认捕获模式时,这是通过复制捕获还是 (*this) 通过引用捕获?是一样的吗?

    当我看到以下工作时我有点困惑 struct A void g void f g 但后来我发现this https stackoverflow com a 16323119 5825294答案非常详细地解释了它是如何工作的 本质上 它归结为t
  • 如何在Windows窗体中打开进程

    我想在我的 Windows 窗体应用程序中打开进程 例如 我希望当用户按下 Windows 窗体容器之一中的按钮时 mstsc exe 将打开 如果他按下按钮 它将在另一个容器上打开 IE DllImport user32 dll SetL
  • 如何在VS2005中使用从.bat而不是.exe启动的外部程序进行调试?

    在我的 c 项目的调试属性中 我选择了 启动外部程序 并选择了我希望将调试器附加到的程序的 exe 但是 现在我需要从 bat 文件而不是 exe 启动程序 但 VS2005 似乎不允许这样做 这可能吗 编辑 为了澄清 我需要调试从 bat
  • 在 clang 中向量化函数

    我正在尝试根据此用 clang 对以下函数进行矢量化铿锵参考 http llvm org docs Vectorizers html 它采用字节数组向量并根据以下条件应用掩码this RFC https www rfc editor org
  • 让 Windows 尝试读取文件

    我正在对 Windows 文件系统进行某种封装 当用户请求打开文件时 Windows 调用我的驱动程序来提供数据 在正常操作中 驱动程序返回缓存的文件内容 但是 在某些情况下 实际文件没有缓存 我需要从网络下载它 问题是是否有可能让 Win
  • 如何将模型绑定到动态创建的类 nancyfx

    首先感谢任何愿意查看我的问题的人 我对 Nancyfx 还很陌生 在尝试将 JSON 有效负载绑定到动态创建的类时遇到问题 我按照这篇文章中的代码动态创建了该类 在C 中动态创建一个类 https stackoverflow com que
  • 将同步 zip 操作转换为异步

    我们有一个现有的库 其中一些方法需要转换为异步方法 但是我不确定如何使用以下方法执行此操作 错误处理已被删除 该方法的目的是压缩文件并将其保存到磁盘 请注意 zip 类不公开任何异步方法 public static bool ZipAndS
  • Adobe Illustrator 中的折线简化如何工作?

    我正在开发一个记录笔划的应用程序 您可以使用定点设备来绘制笔划 在上图中 我绘制了一个笔划 其中包含 453 个数据点 我的目标是大幅减少数据点的数量 同时仍然保持原始笔画的形状 对于那些感兴趣的人 上图笔画的坐标可以作为GitHub 上的
  • NHibernate:无状态会话错误消息无法获取代理

    我正在使用 nHibernate 无状态会话来获取对象 更新一个属性并将对象保存回数据库 我不断收到错误消息 无状态会话无法获取代理 我在其他地方有类似的代码 所以我不明白为什么这不起作用 有谁知道问题可能是什么 我正在尝试更新Screen

随机推荐

  • 队列的基本运算实现

    队列 queue 队列是一种先进先出 first in first out FIFO 的线性表 它只允许在表的一端 队尾 rear 插入元素 而在另一端 队头 front 删除元素 插入操作称为入队或进队 删除操作称为出队或离队 队列示意图
  • matlab读jpg有三个通道,图像为“灰度图像”

    最近用matlab读取 灰度图 jpg格式 居然有三个通道 且灰度值还不一样 那么这是为什么呢 1 灰度图 其实是 灰度图 概念的问题 并不是灰色的图片就是灰度图 正常来说灰度图是某个波段的成像 是由ccd对该波段对应波长的光线的强度感应形
  • 健康保健产品爬虫:Python爬虫获取保健品信息和用户评价

    目录 第一部分 选择目标网站 第二部分 分析网站结构和查询参数
  • 经纬恒润与辉羲智能达成战略投资与业务合作,加速产业智能化进程

    近日 经纬恒润战略投资辉羲智能 并与辉羲智能签署战略合作协议 双方将聚焦未来智慧出行 共同打造基于国产高性能SoC的自动驾驶量产解决方案 助力客户快速实现包括轻地图城市NOA在内的高阶自动驾驶功能量产落地 目前自动驾驶规模化处于全面爆发前夕
  • Nightingale滴滴夜莺监控系统入门(三)--页面功能说明

    Nightingale滴滴夜莺监控系统入门 三 功能模块 V3 4 1 用户资源中心 资产管理系统 任务执行中心 监控告警系统 监控看图 监控大盘 告警策略 部署客户端 生产环境开放服务端端口 部署客户端 这章节主要是介绍夜莺的功能使用 各
  • k8s的安装

    我这里使用vmware创建了三台虚拟机 k8s的虚拟机建议最少2核 4G内存 我的电脑配置不高采用的2核 3G的配置 安装k8s之前需要先安装docker docker的安装参考 docker的安装及使用 docker的安装和使用 骑士99
  • ubuntu16.4虚拟机开机,进入tty1命令终端,无法进入桌面问题始末

    现象 1 ubuntu虚拟机开机频繁出现error failed to start network manager 2 进入tty1 vm login 分析 1 回想到前一天编译工程 由于 lib i386 linux gnu下缺少libu
  • visual studio code导入自定义模块(pycharm中能够运行的文件,vs code报错:未找到指定模块)

    一 先看下目录结构 二 在main py中导入Utils中的模块 直接导入即可 from Utils custom event parse import CustomEventParse 三 在custom event parse py中导
  • [运维

    READS SQL DATA 是 MySQL 存储过程和函数中的一种权限修饰符 用于标识该存储过程或函数只读取数据库的数据而不修改它 这个修饰符通常用于声明存储过程或函数的权限 以告知数据库管理系统该过程或函数不会对数据库进行写操作 从而允
  • 动手学数据分析 Task5

    动手学数据分析 Task5 一 逻辑回归 二 随机森林 三 模型评估 3 1 k折交叉验证 3 2 混淆矩阵 3 3 ROC曲线 一 逻辑回归 LogisticRegression penalty l2 dual False tol 0 0
  • 如何将Zookeeper和Kafka的log4j升级到2.16

    1 删除lib下的jar文件 对于kafka lib 删除 slf4j api 1 7 25 jar slf4j log4j12 1 7 25 jar log4j 1 2 17 jar 对于zk lib 删除 log4j 1 2 17 ja
  • 毕业设计 - stm32单片机的远程WIFI密码锁 - 物联网 嵌入式

    文章目录 0 前言 1 简介 主要器件 实现效果 4 硬件设计 WIFI模块 OLED显示屏 相关原理图 硬件接线 5 软件说明 开发环境介绍 程序下载配置 设备初始化打印的信息 6 部分核心代码 7 最后 0 前言 这两年开始毕业设计和毕
  • kubernetes集群-Master节点升级-kubeadm,kubectl,kubelet升级

    kubernetes Master单节点升级 kubeadm 升级 kubelet 升级 kubectl 升级 生产环境注意事项 由于 kubeadm upgrade 不会升级 etcd 请确保已对其进行了备份 例如 您可以使用 etcdc
  • java setsession_Java Session.setServerAliveInterval方法代码示例

    import com jcraft jsch Session 导入方法依赖的package包 类 private Session startNewSession boolean acquireChannel throws JSchExcep
  • 华为od机试 Java【跳房子2】

    题目 有若干个连续的方格地板 儿童们喜欢在上面玩游戏 在这个游戏中 玩家需要在三个回合内 按照规定的步数 从第一格跳到最后一格 跳到最后的玩家有机会选择一个他们喜欢的房子 直到所有的房子都被选完 当然 游戏中最多房子的人是胜者 但游戏并不那
  • 快速浏览Swift-笔记

    快速浏览Swift 笔记 快速浏览Swift https docs swift org swift book GuidedTour GuidedTour html 变量也常量 多行字符串 使用 let quotation I said I
  • python文件工程化,隐藏源码

    python文件工程化 隐藏源码 py文件转换为pyc文件 全文来自博客https www cnblogs com HByang p 13223118 html pyc介绍 pyc是一种二进制文件 是由py文件经过编译后 生成的文件 是一种
  • 3 个 C 程序示例,用于创建包含数据的文件

    本教程介绍如何使用 C 程序创建文件 在这些示例中 我们将创建新的 HTML 文件并向其中写入一些内容 文件的内容会有所不同 但这三个 C 示例程序应该向大家说明如何使用 fopen fprintf 等 c 文件函数来创建和操作文件 示例一
  • ibm中间键服务器缺少文件夹,存储中间件-MQ常见问题解决方法FAQ.doc

    存储中间件 MQ常见问题解决方法FAQ IBM Websphere MQ FAQ Last Release 2006 1 2 这里整理了IBM Websphere MQ的一些常见错误和解决方法 当发现MQ错误而一时无法解决时 可以参阅这里的
  • 【LibTorch】C++中部署TorchScript模型

    文章目录 1 LibTorch安装 2 C 调用PyTorch模型 2 1 Python中保存tensor数据 2 2 C 中保存tensor数据 2 3 C 加载tensor并调用模型 3 编译执行C 推理用例 3 1 编写CMakeLi