tensorflow C++ 环境搭建及实战

2023-05-16


摘要: 最近在研究如何使用tensorflow c++ API调用tensorflow python环境下训练得到的网络模型文件。参考了很多博客,文档,一路上踩了很多坑,现将自己的方法步骤记录下来,希望能够帮到有需要的人!(本文默认读者对python环境下tensorflow的使用已经比较熟悉了)

方法简要梳理如下:

  1. 安装bazel,然后使用bazel编译tensorflow源码,产生我们需要的库文件。
  2. 在python环境下,使用tensorflow训练一个深度神经网络,本文以mnist为例。将训练好的模型和参数冻结在一个pb文件中。
  3. 在C++环境下,调用pb文件,对图片进行预测。最终结果如下图所示,程序成功识别到图片中的数字为1,且概率为0.95。

具体程序参考项目:

https://github.com/zhangcliff/tensorflow-c-mnist.git

1.安装bazel

echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | sudo tee /etc/apt/sources.list.d/bazel.list
curl https://bazel.build/bazel-release.pub.gpg | sudo apt-key add -
sudo apt-get update
sudo apt-get install bazel

2.tensorflow的下载。

本博文使用的tensorflow版本为1.4,其他版本的c++编译可能会有一些不一样。

git clone https://github.com/tensorflow/tensorflow.git

3.tensorflow的c++编译。

3.1 进入tensorflow文件夹中,首先进行项目配置。

./configure

下面我贴出在我的机器上各选项的选择:值得注意的是,如果我们要使用cuda和cudnn的话,一定要搞清楚自己机器上使用的cuda和cudnn的版本(尤其是cudnn),例如我使用的是cuda8.0和cudnn6.0.21。

3.2 使用bazel命令进行编译。编译的时间比较长,我在i3-4150cpu上编译了一个小时左右的时间。

bazel build --config=opt --config=cuda //tensorflow:libtensorflow_cc.so

如果没有显卡则使用如下命令进行编译

bazel build --config=opt //tensorflow:libtensorflow_cc.so

编译完成后,在bazel-bin/tensorflow中会生成两个我们需要的库文件:libtensorflow_cc.so 和 libtensorflow_framework.so。

在后面我们用C++调用tensorflow时需要链接这两个库文件。


4. 使用tensorflow C++ api调用图模型(.pb文件)。

tensorflow 编译好之后,我们使用tensorflow c++ api调用一个已经冻结的图模型(.pb文件)

具体程序参考项目:

https://github.com/zhangcliff/tensorflow-c-mnist.git

4.1 在python环境下生成一个图模型(.pb文件)

对于tensorflow,在Python环境下的使用是最方便的,tensorflow的python api也是最多最全面的。因此我们在python环境下,训练了一个深度神经网络模型,并将模型和参数都冻结在一个pb文件中。为后面使用C++ API调用这个pb文件做好准备。我们以经典的mnist为例。

数据处理与模型的训练,这里就不多说了(默认读者对python环境下tensorflow的使用已经比较熟悉)。这里要说的是pb文件的生成,使用一下代码:

  1. from tensorflow.python.framework.graph_util import convert_variables_to_constants
  2. graph = convert_variables_to_constants(sess, sess.graph_def, [ "softmax"])
  3. tf.train.write_graph(graph, 'models', 'model.pb',as_text= False)
其中,convert_variables_to_constants()函数将参数变量冻结在图模型中,其中第三个参数为网络输出tensor的名字(name)。因为我的网络输出是这样定义的:y_conv = tf.nn.softmax(logits,name='softmax'),所以我的第三个参数设置为['softmax']。

write_graph()函数生成.pb文件,第二个参数为生成pb文件的文件夹,第三个参数为pb文件的名字。

将上面三行代码加入到你的模型训练的python脚步中,最后便可以得到我们需要的pb文件。


4.2 c++环境下调用pb文件。

第一步,加载模型

  1. Session* session;
  2. Status status = NewSession(SessionOptions(), &session); //创建新会话Session
  3. string model_path= "model.pb";
  4. GraphDef graphdef; //Graph Definition for current model
  5. Status status_load = ReadBinaryProto(Env::Default(), model_path, &graphdef); //从pb文件中读取图模型;
  6. if (!status_load.ok()) {
  7. std:: cout << "ERROR: Loading model failed..." << model_path << std:: endl;
  8. std:: cout << status_load.ToString() << "\n";
  9. return -1;
  10. }
  11. Status status_create = session->Create(graphdef); //将模型导入会话Session中;
  12. if (!status_create.ok()) {
  13. std:: cout << "ERROR: Creating graph in session failed..." << status_create.ToString() << std:: endl;
  14. return -1;
  15. }
  16. cout << "Session successfully created."<< endl;

第二步,使用tensorflow API读取图片。

这里定义了两个函数用于从jpg文件中读取图片:

  1. Status ReadTensorFromImageFile(const string& file_name, const int input_height,
  2. const int input_width, const float input_mean,
  3. const float input_std,
  4. std:: vector<Tensor>* out_tensors)
  1. static Status ReadEntireFile(tensorflow::Env* env, const string& filename,
  2. Tensor* output)
ReadEntireFile()函数读取文件内容,并将其赋值给其第三个输入参数 Tensor* output,但是这个tensor并不能直接输入给刚才我们加载的模型,需要经过一定的预处理。

在ReadTensorFromImageFile()函数,建立一个会话session,在会话中对读取到的tensor进行预处理,例如,对tensor进行解码( DecodeJpeg),resize,归一化等等。


第三步,运行模型

  1. const Tensor& resized_tensor = resized_tensors[ 0];
  2. vector<tensorflow::Tensor> outputs;
  3. string output_node = "softmax";
  4. Status status_run = session->Run({{ "inputs", resized_tensor}}, {output_node}, {}, &outputs);
resized_tensors的类型是std::vector<Tensor>,是一个vector容器。其在程序中,是ReadTensorFromImageFile()函数的最后一个输入参数,对读取到的图片tensor进行预处理后便保存在这个容器中。

模型预测时使用的函数为session->Run({{"inputs", resized_tensor}}, {output_node}, {}, &outputs)。

值得注意的是,"inputs"是图模型输入tensor的名字(name),变量output_node保存的是图模型输出tensor的名字(name)。这两个名字(name)一定要与保存的图模型(.pb)文件中的名字一致,否则会报错。最后得到的输出tensor保存在容器outputs中。

如果你有一个pb文件,可是不知道它的输入输出tensor的名字,我们可以在python环境中使用API加载这个模型,然后将模型中的所有operation打印出来,第一项便是输入tensor,最后一项便是输出tensor。

  1. print(sess.graph.get_operations())
  2. print(sess.graph.get_operations()[ 0])
  3. print(sess.graph.get_operations()[ 1])
第四步,从模型输出tensor中获得各类别的概率。

  1. Tensor t = outputs[ 0]; // Fetch the first tensor
  2. auto tmap = t.tensor< float, 2>(); // Tensor Shape: [batch_size, target_class_num]
  3. int output_dim = t.shape().dim_size( 1); // Get the target_class_num from 1st dimension
  4. // Argmax: Get Final Prediction Label and Probability
  5. int output_class_id = -1;
  6. double output_prob = 0.0;
  7. for ( int j = 0; j < output_dim; j++)
  8. {
  9. std:: cout << "Class " << j << " prob:" << tmap( 0, j) << "," << std:: endl;
  10. if (tmap( 0, j) >= output_prob) {
  11. output_class_id = j;
  12. output_prob = tmap( 0, j);
  13. }
  14. }
  15. std:: cout << "Final class id: " << output_class_id << std:: endl;
  16. std:: cout << "Final class prob: " << output_prob << std:: endl;

4.3 使用cmake进行编译

本例我们建立一个cmake工程,再通过make生成一个可执行文件。首先我们建立一个文件夹取名tensorflow_mnist,在该文件夹下创建子文件夹lib,将刚才编译tensorflow 时产生的两个库文件(libtensorflow_cc.so,libtensorflow_framework.so)放入其中。调用pb文件进行预测的C++文件,取名为tf.cpp,放在tensorflow_mnist目录下。文件结构如下图所示。


下面给出我的CMakeLists.txt的文件内容

cmake_minimum_required (VERSION 2.8.8)
project (tf_example)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -std=c++11 -W")

link_directories(./lib)
include_directories(
   /home/zwx/tensorflow
   /home/zwx/tensorflow/bazel-genfiles
   /home/zwx/tensorflow/tensorflow/contrib/makefile/downloads/nsync/public
   /usr/local/include/eigen3
   /home/zwx/tensorflow/bazel-bin/tensorflow
   /home/zwx/tensorflow/tensorflow/contrib/makefile/gen/protobuf/include
   ) 
add_executable(tf_test  tf.cpp) 
target_link_libraries(tf_test tensorflow_cc tensorflow_framework)
最后进入build文件夹,对该工程进行编译:
cd build
cmake ..
make

不过,在make 这一步很大几率会报错,我将我碰到的几个问题和解决方法写在这里,仅供参考。


问题一: protobuf版本不对


解决方法:安装正确版本的protobuf。在ubuntu16.04,tensorflow1.4下,应该安装protobuf-3.4.0。

问题二:nsync_cv.h文件缺失

正常情况下,该文件应该在路径tensorflow/tensorflow/contrib/makefile/downloads/nsync/public里。如果出现这个问题,很可能是tensorflow/tensorflow/contrib/makefile/下没有downloads文件夹,可能是编译的时候网络不好,没有下载这个文件夹。

解决方法: 进入 tensorflow/tensorflow/contrib/makefile/ 文件夹下,找到脚步。然后回到tensorflow文件夹下,执行该脚本。

./tensorflow/contrib/makefile/download_dependencies.sh

下载完毕后,便会有downloads文件夹,缺失的文件便会包含在其中。


问题三:

解决方法:这个问题的造成原因和问题二是一样的,查看下载好的downloads文件夹,发现其中有一个文件夹为eigen,进入eigen文件夹执行以下命令。

mkdir build
cd build
cmake ..
make
sudo make install
安装完毕后,在usr/local/include目录下会出现eigen3文件夹。


4.4 运行可执行程序

make成功后,在build目录下会出现一个可执行文件tf_test。将一张28*28的数字图片也放在build路径下,文件名为digit.jpg,最后执行tf_test文件。

./tf_test digit.jpg
结果如下图所示:

从中我们可以看到该c++程序识别到digit.jpg图片为数字一,为1的概率为0.95。


总结:这篇博文主要介绍了如何从源码编译tensorflow c++ API,并且使用c++ API调用一个在python环境下已经训练好并冻结参数的模型文件(.pb文件),最终生成一个可执行文件tf_test。通过运行该文件,我们成功识别了手写体数字。

具体脚本参考项目:

https://github.com/zhangcliff/tensorflow-c-mnist.git






本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

tensorflow C++ 环境搭建及实战 的相关文章

  • 关于PiBOT使用的一些问题汇总--ing

    xff1a 多机通讯是按照教程设置环境变量ROS MASTER URI 初始化 pibot init env sh xff0c 使用rostopic已经能够查看 xff0c 但是主机PC无法启动launch 原因 xff1a 个人测试是需要
  • ubuntu 19.10系统解决E: 仓库 “http://ppa.launchpad.net/webupd8team/java/ubuntu eoan Release” 没有 Release 文件。

    在终端换源后遇到E 仓库 http ppa launchpad net webupd8team java ubuntu eoan Release 没有 Release 文件 问题 解决方法 xff1a 将对应的ppa删除即可 第一步 xff
  • 使用org-mode生成晨检报告

    原文地址 https lujun9972 github io blog 2020 04 10 使用org mode生成晨检报告 index html 我们设置了每天8点多自动进行调用一次晨检脚本 xff0c 该脚本会将检查的一些数据存入本地
  • 使用Pi-hole屏蔽广告

    原文地址 https www lujun9972 win blog 2020 12 05 使用pi hole屏蔽广告 index html 目录 获取Pi的对外IP地址安装Pi hole配置DNS配置拦截域名 获取Pi的对外IP地址 我们一
  • 笑话理解之Mature

    原文地址 https www lujun9972 win blog 2020 12 09 笑话理解之mature index html 目录 The difference between government bonds and men T
  • 笑话理解之Hearing

    原文地址 https www lujun9972 win blog 2020 12 09 笑话理解之hearing index html 目录 The Hearing Problem The Hearing Problem In a chu
  • Emacs 作为 MPD 客户端

    原文地址 https www lujun9972 win blog 2022 06 26 emacs 作为 mpd 客户端 index html 今天才知道 xff0c Emacs居然内置了一个 mpc el 可以将 Emacs 转换为 M
  • 编译SONiC交换机镜像(转,参考2)

    sonic buildimage 编译SONiC交换机镜像 描述 以下是关于如何为网络交换机构建 ONIE 兼容网络操作系统 xff08 NOS xff09 安装程序镜像的说明 xff0c 以及如何构建在NOS内运行的Docker镜像 请注
  • Emacs 作为 MPD 客户端

    原文地址 https www lujun9972 win blog 2022 06 26 emacs 作为 mpd 客户端 index html 今天才知道 xff0c Emacs居然内置了一个 mpc el 可以将 Emacs 转换为 M
  • 使用 calc 计算保险实际收益率

    原文地址 https www lujun9972 win blog 2022 08 10 使用 calc 计算保险实际收益率 index html 今天某银行的客户经理来推销一个 增额终身寿险 xff0c 号称是能锁定3 5 的收益率 具体
  • Emacs使用Deft管理笔记

    1 Deft介绍 Deft是一款写作和管理纯文本格式笔记的工具 通过它可以快速的查到或新建笔记 Deft的下载地址是Deft 也可以通过浏览或者拷贝git仓库 xff1a git clone git jblevins org git def
  • linux挂载samba文件系统的方法

    1 手工挂载 有两个命令可以用来手工挂载samba文件系统 xff0c 一个是mount xff0c 一个是smbmount 1 1 使用mount命令挂载 mount就是用于挂载文件系统的 xff0c SMB做为网络文件系统的一种 xff
  • DB2里面如何进行快速分页?就像mysql的limit之类的

    从百度知道里看到的 xff0c 记录下来以防忘记了 只查询前10行 fetch first 10 rows only SELECT SALE DATE SUM SALE MONEY AS SUM MONEY FROM SALE REPORT
  • linux时间与Windows时间不一致的解决

    转载至http goodluck1982 blog sohu com 138950694 html 一 首先要弄清几个概念 xff1a 1 系统时间 与 硬件时间 系统时间 一般说来就是我们执行 date命令看到的时间 xff0c linu
  • java list中删除元素用remove()报错的fail-fast机制原理以及解决方案

    java list中删除元素用remove 报错的fail fast机制原理以及解决方案 现在有一个list 有6个元素 xff0c 值分别是1 5 5 8 5 10 xff0c 现需要删除值为5的元素 第一种 import java ut
  • PDF 缩略图无法正常显示 解决办法

    先交代一下配置 xff1a win10 43 福晰阅读器 43 Adobe Acrobat DC 有时候PDF无法显示缩略图 xff0c 像下面这样子 提前设置好两个点 xff1a 1 从不显示缩略图的复选框的勾 xff0c 取消 xff0
  • xmanager7开启Xstart 连接远程ubuntu虚拟机

    在想要建立连接的ubuntu虚拟机上安装xterm xff0c 此处我是用的是ubuntu20 04 sudo apt install xterm span class token operator span y 然后打开xmanager7
  • C#窗体应用程序常用控件介绍

    下面图片列出了我目前常用的一些控件 xff1a 1 窗体Form 新建一个C 窗体应用程序 xff0c 默认都会有一个窗体控件 xff0c 窗体就是我们应用程序最大的那个窗口了 窗体常用的属性有 xff1a xff08 1 xff09 St
  • docker load 报 Error processing tar file unexpected EOF 解决

    43 43 echo 201904 0 dirty 20191029 021252 43 43 sed 39 s g 39 43 sudo LANG 61 C chroot fsroot docker tag docker database
  • 论文翻译-Defending Against Universal Attacks Through Selective Feature Regeneration

    CVPR2020 通过选择性特征再生防御通用攻击 有一段时间自己看的论文都没有把相应的翻译或者笔记整理成文档的形式了 xff0c 虽然在PDF上会有一些标注 xff0c 但是觉得还是成稿的形式会方便很长一段时间之后回过头继续看及时回顾起来

随机推荐

  • 2020年3月24日360内推笔试

    考试时长2个半小时 xff0c 笔试题分为三个部分 xff1a 1 20道逻辑选择题 xff08 包括图形找规律 数字找规律 小学奥数 xff09 2 40道基础选择题 xff08 包括计算机网络 C 43 43 补全代码 查看输入输出 p
  • 【CVPR2020】CNN-generated images are surprisingly easy to spot... for now(假图检测/CNN合成图像的检测/图像取证)

    CVPR2020 ORAL 这篇论文的思路很简单 xff0c 就是涉及到的数据收集和用到的GAN太多了 xff0c 一般人可能都没精力去搞 简要概括 xff0c 实验量大 xff0c 效果好 xff0c 抓人眼球哈哈 这篇论文涉及到的GAN
  • Classication of Time-Series Images Using Deep Convolutional Neural Networks[用深度卷积神经网络对时序图分类]

    今天要分享的论文是ICMV2017的一篇非常巧妙的论文 作者是 论文下载链接 xff1a https arxiv org abs 1710 00886 关于论文的源码下载链接 xff1a https sites google com sit
  • 吴恩达机器学习+deeplearning课程笔记----干货链接分享

    分享两个GitHub链接 xff0c 今天看到的 xff0c 超赞超赞不能更赞了 xff0c 答应我一定要去看好吗 不论是笔记还是github中分享的其它资源 xff0c 课程视频链接 xff0c PPT下载 xff0c 作业布置等都超棒
  • 英文写作经典指导书--学术写作必备

    以下书籍文章已整理PDF版上传至我的CSDN下载资源中 xff0c 链接 资源审核不过 尴尬 xff0c 要的朋友在博客下面留言好啦 xff0c 我看到后会发送到你邮箱哒 2018 7 23 我把压缩包上传到了百度网盘里 xff0c 需要的
  • 图像处理和机器学习有什么关系?

    一篇很不错的文章 xff0c 分享给博客的朋友们 作者 xff1a 许铁 巡洋舰科技 链接 xff1a https www zhihu com question 21665775 answer 281946017 来源 xff1a 知乎 著
  • 【CVPR 2018】Learning Rich Features for Image Manipulation Detection(图像篡改检测)

    今天来给大家分享一篇CVPR2018的论文 xff0c 检测图像的篡改区域 xff0c 用更快的R CNN网络定位图像被篡改的部分 xff0c 练就PS检测的火眼金睛 让PS痕迹无处可逃 这就将图像鉴伪 xff0c 图像取证这方面与深度学习
  • 国内免费汉语语料库-NLP

    自转载https www sohu com a 196504864 236505 xff08 一 xff09 国家语委 1国家语委现代汉语语料库http www cncorpus org 现代汉语通用平衡语料库现在重新开放网络查询了 重开后
  • 【模糊数学】模糊逻辑,隶属度,模糊逻辑应用,模糊推理过程

    update 下一篇博客我将会讲如何用这篇博客的模糊推理过程构建一个图像边缘检测的模糊推理系统 链接 xff1a https blog csdn net luolan9611 article details 94296622 这是我的一项大
  • FRR BGP 协议分析 5 -- 路由更新(2)

    处理NLRI 获取NLRI的报文长度 xff0c 填入nlris NLRI UPDATE xff0c 到现在为止nlris里面的4种类型 如果有的话 xff0c 已经全部填写到nlris数组结构体里面 然后我们遍历这个数组 xff0c 处理
  • 从输入 URL 到页面加载完成中间都经历了什么

    摘要 目录 1 chrome浏览器资源加载时序分析2 w3c提供的接口performance timing分析3 一个完整的URL 解析过程细分介绍3 1 缓存相关3 1 1 URL解析 3 2 网络相关3 2 1 DNS解析3 2 2 建
  • ALC5621声卡调试记录

    转载请注明出处 xff1a https blog csdn net luomin5417 article details 80731790 平台 imx6q 内核版本 linux 3 14 1 硬件连接 图 1 1 硬件连接 2 设备树修改
  • Pytorch转Caffe最简单方法

    由于需要移植模型到比特大陆 xff0c 华为昇腾这些平台 他们基本都支持caffe的模型 xff0c 对其他模型支持不太好 用其他方法pytorch转caffe不然就是绕道太多 xff0c 不然就是很多坑 这里记录一个最简单的方法 xff1
  • No module named ‘index‘ after install pyflann

    如题 xff0c 墙内没有 I have some problems installing pyflann in python 3 7 3 after execute pip install pyflann The installation
  • 多维 opencv Mat访问

    你看完这篇文章之后 xff0c 将学会以下知识 xff1a 二维 三维 四维等任意维度的Mat的常用建立方法 xff1b 任意维度Mat中值的索引 xff1b 以及一些Mat常用的操作 下面是对各维度矩阵的介绍 xff1a 注意 xff1a
  • H264/H265码流的编码码率设置

    一 什么是视频码率 xff1f 视频码率是视频数据 xff08 视频色彩量 亮度量 像素量 xff09 每秒输出的位数 一般用的单位是kbps 二 设置视频码率的必要性 在视频会议应用中 xff0c 视频质量和网络带宽占用是矛盾的 xff0
  • 检测图像失焦、偏色、亮度异常

    要求通过算法检测监控设备是否存在失焦 偏色 亮度异常等问题 问题本身不难 xff0c 在网上查看了一些资料 xff0c 自己也做了一些思考 xff0c 方法如下 xff1a 1 失焦检测 失焦的主要表现就是画面模糊 xff0c 衡量画面模糊
  • Jupyter 安装与使用

    最近由于项目需要 xff0c 开始学习python xff0c 然后发现一个非常有用的python交互式编辑器 xff0c 非常容易上手而且非常有用和实在 xff0c 本博文是对学习jupyter notebook的一个汇总和记录 xff0
  • C语言 print()函数 规则,格式 意思

    C语言 print 函数 规则 xff0c 格式 意思 C语言格式字符print 函数 printf后面的参数包括 格式控制字符串 和输出变量的列表 格式控制字符串 由格式控制字符和普通字符 其中前者以 开始加某一个特殊字符 比如 d为输出
  • tensorflow C++ 环境搭建及实战

    摘要 xff1a 最近在研究如何使用tensorflow c 43 43 API调用tensorflow python环境下训练得到的网络模型文件 参考了很多博客 xff0c 文档 xff0c 一路上踩了很多坑 xff0c 现将自己的方法步