深度学习部署--tensorflow 用c++调用前向

2023-11-16

使用TensorFlow C++ API构建线上预测服务 - 第一篇

Oct 9, 2017 |  tensorflow

目前,TensorFlow官方推荐使用Bazel编译源码和安装,但许多公司常用的构建工具是CMake。TensorFlow官方并没有提供CMake的编译示例,但提供了MakeFile文件,所以可以直接使用make进行编译安装。另一方面,模型训练成功后,官方提供了TensorFlow Servering进行预测的托管,但这个方案过于复杂。对于许多机器学习团队来说,一般都有自己的一套模型托管和预测服务,如果使用TensorFlow Servering对现存业务的侵入性太大,使用TensorFlow C++ API来导入模型并提供预测服务能方便的嵌入大部分已有业务方案,对这些团队来说比较合适。

本文以一个简单网络介绍从线下训练到线上预测的整个流程,主要包括以下几点:

  • 使用Python接口训练模型
  • 使用make编译TensorFlow源码,得到静态库
  • 调用TensorFlow C++ API编写预测代码,使用CMake构建预测服务

使用Python接口训练模型

这里用一个简单的网络来介绍,主要目的是保存网络结构和参数,用于后续的预测。

 
     
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
 
     
import tensorflow as tf
import numpy as np
with tf.Session() as sess:
a = tf.Variable( 5.0, name= 'a')
b = tf.Variable( 6.0, name= 'b')
c = tf.multiply(a, b, name= 'c')
sess.run(tf.global_variables_initializer())
print(a.eval()) # 5.0
print(b.eval()) # 6.0
print(c.eval()) # 30.0
tf.train.write_graph(sess.graph_def, 'simple_model/', 'graph.pb', as_text= False)

这个网络有两个输入,a和b,输出是c,最后一行用来保存模型到simple_model目录。运行后会在simple_model目录下生成一个graph.pb的protobuf二进制文件,这个文件保存了网络的结构,由于这个例子里没有模型参数,所以没有保存checkpoint文件。

源码编译TensorFlow

官方详细介绍可以看这里源码编译TensorFlow。其实很简单,以maxOS为例,只要运行以下命令即可,其他操作系统也有相应的命令。编译过程大概需要半小时,成功后会在tensorflow/tensorflow/contrib/makefile/gen/lib下看到一个100多MB的libtensorflow-core.a库文件。maxOS需要使用build_all_linux.sh,并且只能用clang,因为有第三方依赖编译时把clang写死了。

 
     
1
2
3
 
     
git clone https://github.com/tensorflow/tensorflow.git
cd tensorflow
tensorflow/contrib/makefile/build_all_linux.sh

后续如果要依赖TensorFlow的头文件和静态库做开发,tensorflow/tensorflow/contrib/makefile目录下的几个目录需要注意:

  • downloads 存放第三方依赖的一些头文件和静态库,比如nsync、Eigen等
  • gen 存放TensorFlow生成的C++ PB头文件、TensorFlow的静态库、ProtoBuf的头文件和静态库等等

使用TensorFlow C++ API编写预测代码

预测代码主要包括以下几个步骤:

  • 创建Session
  • 导入之前生成的模型
  • 将模型设置到创建的Session里
  • 设置模型输入输出,调用Session的Run做预测
  • 关闭Session

创建Session

 
     
1
2
3
4
5
6
7
 
     
Session* session;
Status status = NewSession(SessionOptions(), &session);
if (!status.ok()) {
std:: cout << status.ToString() << std:: endl;
} else {
std:: cout << "Session created successfully" << std:: endl;
}

导入模型

 
     
1
2
3
4
5
6
7
 
     
GraphDef graph_def;
Status status = ReadBinaryProto(Env::Default(), "../demo/simple_model/graph.pb", &graph_def);
if (!status.ok()) {
std:: cout << status.ToString() << std:: endl;
} else {
std:: cout << "Load graph protobuf successfully" << std:: endl;
}

将模型设置到创建的Session里

 
     
1
2
3
4
5
6
 
     
Status status = session->Create(graph_def);
if (!status.ok()) {
std:: cout << status.ToString() << std:: endl;
} else {
std:: cout << "Add graph to session successfully" << std:: endl;
}

设置模型输入

模型的输入输出都是Tensor或Sparse Tensor。

 
     
1
2
3
4
5
 
     
Tensor a(DT_FLOAT, TensorShape()); // input a
a.scalar< float>()() = 3.0;
Tensor b(DT_FLOAT, TensorShape()); // input b
b.scalar< float>()() = 2.0;

预测

 
     
1
2
3
4
5
6
7
8
9
10
11
12
13
 
     
std:: vector< std::pair< string, tensorflow::Tensor>> inputs = {
{ "a", a },
{ "b", b },
}; // input
std:: vector<tensorflow::Tensor> outputs; // output
Statuc status = session->Run(inputs, { "c"}, {}, &outputs);
if (!status.ok()) {
std:: cout << status.ToString() << std:: endl;
} else {
std:: cout << "Run session successfully" << std:: endl;
}

查看预测结果

 
     
1
2
 
     
auto c = outputs[ 0].scalar< float>();
std:: cout << "output value: " << c() << std:: endl;

关闭Session

 
     
1
 
     
session->Close();

完整的代码在https://github.com/formath/tensorflow-predictor-cpp,路径为src/simple_model.cc

使用CMake构建预测代码

这里主要的问题是头文件和静态库的路径要正确,包括TensorFlow以及第三方依赖。 以macOS为例,其他平台路径会不一样。

头文件路径

 
     
1
2
3
4
5
 
     
tensorflow // TensorFlow头文件
tensorflow /tensorflow/contrib /makefile/gen /proto // TensorFlow PB文件生成的pb.h头文件
tensorflow /tensorflow/contrib /makefile/gen /protobuf-host/include // ProtoBuf头文件
tensorflow /tensorflow/contrib /makefile/downloads /eigen // eigen头文件
tensorflow /tensorflow/contrib /makefile/downloads /nsync/ public // nsync头文件

静态库路径

 
     
1
2
3
 
     
tensorflow/tensorflow/contrib/makefile/gen/ lib // TensorFlow静态库
/tensorflow/tensorflow/contrib/makefile/gen/protobuf-host/ lib // protobuf静态库
/tensorflow/tensorflow/contrib/makefile/downloads/nsync/builds/default.macos.c++ 11 / / nsync静态库

编译时需要这些静态库

 
     
1
2
3
4
 
     
libtensorflow-core.a
libprotobuf.a
libnsync.a
其他: pthread m z

CMake构建

 
     
1
2
3
4
5
 
     
git clone https://github.com/formath/tensorflow-predictor-cpp.git
cd tensorflow-predictor-cpp
mkdir build && cd build
cmake ..
make

构建完成后在bin路径下会看到一个simple_model可执行文件,运行./simple_model即可看到输出output value: 6。 需要注意的时,编译选项里一定要加这些-undefined dynamic_lookup -all_load,否则在编译和运行时会报错,原因可见dynamic_lookupError issues

以上用c = a * b一个简单的网络来介绍整个流程,只要简单的修改即可应用到复杂模型中去,更复杂的一个例子可见src/deep_model.cc

参考

url:http://mathmach.com/2017/10/09/tensorflow_c++_api_prediction_first/

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习部署--tensorflow 用c++调用前向 的相关文章

  • 提高 React 项目整洁度的 21 个最佳实践

    React 在如何组织结构方面非常开放 这正是为什么我们有责任保持项目的整洁和可维护性 今天 我们将讨论一些改善 React 应用程序健康状况的最佳实践 这些规则被广泛接受 因此 掌握这些知识至关重要 所有内容都将以代码展示 所以做好准备
  • 端口扫描技术

    端口扫描 常见的扫描类型 全连接扫描 TCP connect 扫描 半连接扫描 TCP SYN 扫描 IP 头信息 dumb 扫描 秘密扫描 TCP FIN 扫描 TCP ACK 扫描 NULL 扫描 XMAS 扫描 SYN ACK 扫描

随机推荐

  • SQL编程:存储过程、触发器、函数(实例基于MySQL5.7.12)

    SQL编程基础 A 编程环境 即存储过程 触发器和函数中进行SQL编程 所以有些语法并不能应用于普通的SQL应用场景 如命令行直接SQL查询 B 变量声明 1 全局变量 声明 set 变量名 值 读取 select 变量名 赋值 set 变
  • 联想gen系列服务器,Hpe Microserver Gen10 Plus开箱

    Hpe Microserver Gen10 Plus开箱 2021 04 19 10 53 23 25点赞 69收藏 83评论 心水很久的gen10 plus终于到了 关注了很久终于下手了 在值得买好像都没看到gen10 plus的开箱 那
  • vuex的持久化插件

    目的 让在vuex中管理的状态数据同时存储在本地 可免去自己存储的环节 在开发的过程中 像用户信息 名字 头像 token 需要vuex中存储且需要本地存储 再例如 购物车如果需要未登录状态下也支持 如果管理在vuex中页需要存储在本地 1
  • 参考《一个64位操作系统的设计与实现》,自己写操作系统(一)

    1 安装VMware虚拟机 版本16 下载地址 http downdownxia com down VMware16lsb rar key fa4505a42b82aa65195be879fc84defd 2 安装centos系统 版本6
  • 【项目实战】MySQL查询计算某个字符在某个字段中的数量

    一 背景说明 表sys dept中有个字段 ancestors ancestors的值是含有 逗号 现在需要计算 逗号 这个字符串在ancestors中出现的数量 二 解决方案 SELECT dept id dept name ancest
  • 2023华为od机试 Python【最长公共后缀】

    题目 我们现在要实现一个功能找到字符串数组 中的最长公共后缀如果不存在公共后缀 输入描述 abc bbc c 输出描述 c 示例1 输入 abc bbc c 输出 c 说明 返回公共后缀 c 示例2 输入 aa bb cc 输出 Zero
  • 解决微信小程序报错:[渲染层网络层错误] Failed to load local image resource

    一 场景 写了一个图片点击 全屏展示的组件 页面图片 gt 点击 gt 打开全屏遮罩层显示大图片 1控制元素展示的变量 data photoShow false 2图片点击函数 onClick const url null e curren
  • Shell的read 读取控制台输入、read的使用

    文章目录 1 read 读取控制台输入 1 1基本语法 1 2read的使用 如果想看更详细的Shell总结请到我之前写的博客https blog csdn net Redamancy06 article details 126048299
  • com.sun.org.apache.xerces.internal.impl.io.MalformedByteSequenceException: Invalid byte 2 of 2-byte

    com sun org apache xerces internal impl io MalformedByteSequenceException Invalid byte 2 of 2 byte UTF 8 sequence 分析 这个问
  • YOLO-----关于正负样本、Loss、IOU、怎样去平衡正负样本的问题?

    关于正负样本 Loss IOU 怎样去平衡正负样本的问题 1 关于正负样本 2 Loss计算 3 IOU GIOU DIOU CIOU 4 怎样去平衡正负样本的问题 先整理一下anchor的概念 常用的anchor定义 Faster R C
  • MySQL 8 安装教程

    MySQL 8发布了 据说相比MySQL 5速度提升了2倍 今天来搞一搞MySQL 8 一 下载MySQL 8 1 首先当然是下载安装包了 下载地址 点击下载MySQL 8 这个页面相信大家都熟悉 我就不多说了 2 将下载的压缩包解压 解压
  • 全网最简洁的mpy-cross教程

    大家知道我一向精干 不喜欢搞花儿的 如果去mpy官网看mpy cross的相关资料 估计又得绕蒙 跟我来 保证你三分钟学会 但是本文不涉及原理 第一 mpy cross是干嘛滴 答 把py文件转成mpy系统读的mpy文件 术语咱不懂 叫交叉
  • H3C交换机如何配置SNMP协议?

    1 使用telnet 登陆设备 system view snmp agent snmp agent community read public snmp agent sys infoversion all dis cur save 保存 Y
  • 操作系统原理大题

    一 地址变换和求FAT表大小 某一页表内容自0 7依次为03 07 0B 11 1A 1D 20 22 请计算页面大小为1K和4K时的逻辑地址134D对应的物理地址 首先 将134D转换为二进制数为 0001001101001101 1k为
  • 【2024届校招内推:NTAA84y】腾讯云智研发中心

    云智校招新官网查看最新岗位情况 云智研发中心2024届校园招聘官网 内推码 NTAA84y 云智研发公司2024届校园招聘启动啦 腾讯旗下子公司 八大类岗位 五大城市全面开放 在喜欢的城市 做喜欢的工作 期待正能量 共担当 实干家的你加入云
  • dumpsys meminfo 的原理和应用

    什么是dumpsys meminfo Android中通过命令dumpsys meminfo package name pid 查看指定进程的内存使用情况 通过输出的信息 可以看出来应用在内存哪里分配出现了问题 比如native heap
  • 华为服务器sn号查询网站,linux 查询服务器sn

    linux 查询服务器sn 内容精选 换一换 Linux云服务器变更规格时 可能会发生磁盘挂载失败的情况 因此 变更规格后 需检查磁盘挂载状态是否正常 本节操作介绍变更规格后检查磁盘挂载状态的操作步骤 以root用户登录云服务器 执行以下命
  • top 命令

    NAME top display Linux tasks SYNOPSIS top hv abcHimMsS d delay n iterations p pid pid a 按内存使用排序 b 批处理 c 显示完整的命令 d 指定间隔时间
  • 文章目录 定义 抽象类型定义 存储结构 顺序存储 定长顺序存储结构 堆式顺序存储结构 链式存储 串的链式存储结构 定义 串是一种内容受限的线性表 串 字符串 由零个或多个字符组成的有限序列 子串 串的任意个连续的字符组成的子序列 主串 包含
  • 深度学习部署--tensorflow 用c++调用前向

    使用TensorFlow C API构建线上预测服务 第一篇 Oct 9 2017 tensorflow 文章目录 1 使用Python接口训练模型 2 源码编译TensorFlow 3 使用TensorFlow C API编写预测代码 3