如何从给定模型获取 Graph(或 GraphDef)?

2023-11-29

我有一个使用 Tensorflow 2 和 Keras 定义的大模型。 该模型在 Python 中运行良好。现在,我想将它导入到 C++ 项目中。

在我的 C++ 项目中,我使用TF_GraphImportGraphDef功能。 如果我准备的话效果很好*.pb使用以下代码创建文件:

    with open('load_model.pb', 'wb') as f:
        f.write(tf.compat.v1.get_default_graph().as_graph_def().SerializeToString())

我已经在使用 Tensorflow 1 (使用 tf.compat.v1.* 函数)编写的简单网络上尝试了此代码。效果很好。

现在我想将我的大模型(开头提到的,使用 Tensorflow 2 编写的)导出到 C++ 项目中。为此,我需要获得一个Graph or GraphDef我的模型中的对象。问题是:如何做到这一点?我没有找到任何属性或函数来获取它。

我也尝试过使用tf.saved_model.save(model, 'model')保存整个模型。它生成一个包含不同文件的目录,包括saved_model.pb文件。不幸的是,当我尝试使用 C++ 加载此文件时TF_GraphImportGraphDef函数,程序抛出异常。


生成的protocol buffers文件tf.saved_model.save不包含GraphDef消息,但是一个SavedModel。你可以遍历那个SavedModel在Python中获取其中嵌入的图形,但这不会立即作为冻结图形工作,因此正确处理可能会很困难。相反,C++ API 现在包含一个LoadSavedModel调用允许您从目录加载整个保存的模型。它应该看起来像这样:

#include <iostream>
#include <...>  // Add necessary TF include directives

using namespace std;
using namespace tensorflow;

int main()
{
    // Path to saved model directory
    const string export_dir = "...";
    // Load model
    Status s;
    SavedModelBundle bundle;
    SessionOptions session_options;
    RunOptions run_options;
    s = LoadSavedModel(session_options, run_options, export_dir,
                       // default "serve" tag set by tf.saved_model.save
                       {"serve"}, &bundle));
    if (!.ok())
    {
        cerr << "Could not load model: " << s.error_message() << endl;
        return -1;
    }
    // Model is loaded
    // ...
    return 0;
}

从这里开始,您可以做不同的事情。也许您会最舒服地将保存的模型转换为冻结图,使用FreezeSavedModel,这应该允许您像以前一样做事情:

GraphDef frozen_graph_def;
std::unordered_set<string> inputs;
std::unordered_set<string> outputs;
s = FreezeSavedModel(bundle, &frozen_graph_def,
                     &inputs, &outputs));
if (!s.ok())
{
    cerr << "Could not freeze model: " << s.error_message() << endl;
    return -1;
}

否则,您可以直接使用保存的模型对象:

// Default "serving_default" signature name set by tf.saved_model_save
const SignatureDef& signature_def = bundle.GetSignatures().at("serving_default");
// Get input and output names (different from layer names)
// Key is input and output layer names
const string input_name = signature_def.inputs().at("my_input").name();
const string output_name = signature_def.inputs().at("my_output").name();
// Run model
Tensor input = ...;
std::vector<Tensor> outputs;
s = bundle.session->Run({{input_name, input}}, {output_name}, {}, &outputs));
if (!s.ok())
{
    cerr << "Error running model: " << s.error_message() << endl;
    return -1;
}
// Get result
Tensor& output = outputs[0];
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何从给定模型获取 Graph(或 GraphDef)? 的相关文章

  • 如何在win32中使用GetSaveFileName保存文件?

    我编写此代码是为了获取 fileName 来保存我的文件 include stdafx h include
  • Geodjango距离查询未检索到正确的结果

    我正在尝试根据地理位置的接近程度来检索一些帖子 正如您在代码中看到的 我正在使用 GeoDjango 并且代码在视图中执行 问题是距离过滤器似乎被完全忽略了 当我检查查询集上的距离时 我得到了预期距离 1m 和 18km 但 18km 的帖
  • 使用 C# 中的 Google 地图 API 和 SSIS 包获取行驶距离

    更新 找到了谷歌距离矩阵并尝试相应地修改我的代码 我在这里收到无效参数错误 return new GeoLocation dstnc uri ToString catch return new GeoLocation 0 0 https 基
  • Matplotlib 中 x 轴标签的频率和旋转

    我在下面编写了一个简单的脚本来使用 matplotlib 生成图形 我想将 x tick 频率从每月增加到每周并轮换标签 我不知道从哪里开始 x 轴频率 我的旋转线产生错误 TypeError set xticks got an unexp
  • Jython 和 SAX 解析器:允许的实体不超过 64000 个?

    我做了一个简单的测试xml saxJython 中的解析器在处理大型 XML 文件 800 MB 时遇到以下错误 Traceback most recent call last File src project xmltools py li
  • 无法为 wsdl 文件创建服务引用

    I have wsdl文件和xsd我本地机器上的文件 我想在项目中添加服务引用 我没有网络服务 我只有wsdl file 我收到以下错误 The document was understood but it could not be pro
  • Python:IndexError:修改代码后列表索引超出范围

    我的代码应该提供以下格式的输出 我尝试修改代码 但我破坏了它 import pandas as pd from bs4 import BeautifulSoup as bs from selenium import webdriver im
  • 返回表示每组内最大值的索引的一系列数字位置

    考虑一下这个系列 np random seed 3 1415 s pd Series np random rand 100 pd MultiIndex from product list ABDCE list abcde One Two T
  • 为什么 f(i = -1, i = -1) 是未定义的行为?

    我正在读关于违反评估顺序 http en cppreference com w cpp language eval order 他们举了一个令我困惑的例子 1 如果标量对象上的副作用相对于同一标量对象上的另一个副作用是无序的 则行为未定义
  • Anaconda 无法导入 ssl 但 Python 可以

    Anaconda 3 Jupyter笔记本无法导入ssl 但使用Atom终端导入ssl没有问题 我尝试在 Jupyter 笔记本中导入 ssl 但出现以下错误 C ProgramData Anaconda3 lib ssl py in
  • SocketIO + Flask 检测断开连接

    我在这里有一个不同的问题 但意识到它可以简化为 如何检测客户端何时从页面断开连接 关闭其页面或单击链接 换句话说 套接字连接关闭 我想制作一个带有更新用户列表的聊天应用程序 并且我在 Python 上使用 Flask 当用户连接时 浏览器发
  • C 语言中的 Alpha 混合 2 RGBA 颜色[重复]

    这个问题在这里已经有答案了 可能的重复 如何快速进行阿尔法混合 https stackoverflow com questions 1102692 how to do alpha blend fast 对 2 个 RGBA 整数 颜色进行
  • “必须声明标量变量”错误[重复]

    这个问题在这里已经有答案了 必须声明标量变量 Id SqlConnection con new SqlConnection connectionstring con Open SqlCommand cmd new SqlCommand cm
  • 如果“嵌入式”SQL 2008 数据库文件不存在,如何创建它?

    我使用 C ADO Net 和在 Server Management Studio 中创建的嵌入式 MS SQL 2008 数据库文件 附加到 MS SQL 2008 Express 创建了一个数据库应用程序 有人可以向我指出一个资源 该资
  • 双击打开 ipython 笔记本

    相关文章 通过双击 osx 打开 ipython 笔记本 https stackoverflow com questions 16158893 open an ipython notebook via double click on osx
  • 如何提高环复杂度?

    对于具有大量决策语句 包括 if while for 语句 的方法 循环复杂度会很高 那么我们该如何改进呢 我正在处理一个大项目 我应该减少 CC gt 10 的方法的 CC 并且有很多方法都存在这个问题 下面我将列出一些例如我遇到的问题的
  • 将上下文管理器的动态可迭代链接到单个 with 语句

    我有一堆想要链接的上下文管理器 第一眼看上去 contextlib nested看起来是一个合适的解决方案 但是 此方法在文档中被标记为已弃用 该文档还指出最新的with声明直接允许这样做 自 2 7 版起已弃用 with 语句现在支持此
  • 多个对象以某种方式相互干扰[原始版本]

    我有一个神经网络 NN 当应用于单个数据集时 它可以完美地工作 但是 如果我想在一组数据上运行神经网络 然后创建一个新的神经网络实例以在不同的数据集 甚至再次同一组数据 上运行 那么新实例将产生完全错误的预测 例如 对 XOR 模式进行训练
  • C++ Boost ASIO 简单的周期性定时器?

    我想要一个非常简单的周期性计时器每 50 毫秒调用我的代码 我可以创建一个始终休眠 50 毫秒的线程 但这很痛苦 我可以开始研究用于制作计时器的 Linux API 但它不可移植 I d like使用升压 我只是不确定这是否可能 boost
  • 嵌入式二进制资源 - 如何枚举嵌入的图像文件?

    我按照中的说明进行操作这本书 http www apress com book view 9781430225492 关于资源等的章节 我不太明白的是 如何替换它 images Add new BitmapImage new Uri Ima

随机推荐

  • 有没有办法在应用程序设置中使用字典或 xml?

    我必须在应用程序设置中存储复杂类型 我认为将其存储为 XML 效果最好 问题是我不知道如何存储 XML 我更喜欢将其存储为托管 XML 而不是仅使用必须在每次访问时解析它的原始 XML 字符串 我设法设置了Type设置为 XDocument
  • 如何通过 Google Play 商店链接到 Android 应用程序

    我有一个免费版本的应用程序 我想通过单击应用程序上的购买按钮来链接应用程序商店 我该怎么做 我完全不知道 请帮助我编写一些代码 提前致谢 我所做的是 public void onCreate Bundle savedInstanceStat
  • Swift didSet 获取数组的索引

    假设我有一个数组 var intArray Int 1 2 3 4 5 didSet print index of value that was modified if I do intArray 2 10 里面可以写什么didSet为了打
  • 使用 CPBPressureTouch GestureRecognizer 检测敲击压力强度

    它与一个 UIButton 配合得很好 void viewDidLoad super viewDidLoad CPBPressureTouchGestureRecognizer recognizer CPBPressureTouchGest
  • unload 事件可以用来可靠地触发 ajax 请求吗?

    我需要一种方法来监视用户编辑会话 我正在审查的解决方案之一将让我使用unload发送 ajax 请求以通知服务器编辑会话结束的事件 看 监控用户会话以防止编辑冲突 我的 相当有限的 阅读unload事件表示附加到该处理程序的代码必须快速运行
  • 如何为 React JSX 编写定义文件

    我想为 Summernote jsx 编写一个自定义定义文件 这样我就不会找不到 react summernote 模块 我已经写了 declare var ReactSummernote JSX ElementClass declare
  • Android POST 请求不起作用

    我正在这样做 Override protected Void doInBackground String strings try String query username strings 0 duration strings 1 dist
  • 检查变量是否为空或已填充

    我有以下问题 序言程序 man thomas 2010 man leon 2011 man thomas 2012 man Man once man Man problem man thomas true i want only on tr
  • 如何获取元素的文本节点?

    div class title I am text node a class edit Edit a div 我希望获得 我是文本节点 不希望删除 编辑 标签 并且需要跨浏览器解决方案 var text title contents fil
  • html/php,上传的文件未存储在$_FILES中

    我有一个表单 用户可以在其中提交对象的描述 包括图像 并且有 JavaScript 为 1 对象描述添加一组附加输入 提交表单时 文件信息不会存储在 FILES 中 表单标签是
  • NoClassDefFoundError:android.support.v4.util.ArrayMap

    在 JellyBean 上出现此错误 01 11 18 26 52 030 E AndroidRuntime 16517 FATAL EXCEPTION main 01 11 18 26 52 030 E AndroidRuntime 16
  • Play 框架上公共文件夹外部的资产映射

    我们有大量图像需要存储在外部路径中 即播放应用程序文件夹之外 我们如何才能将其作为资产来播放 以便将其作为网络服务器进行流式传输 你可能已经看过 Play 的有关资产的文档 除了 Play 的标准资产之外 您还可以定义自己的资产 In co
  • 在 Orchard CMS 中使用 Document.cshtml 的替代方案

    我目前正在开发一个需要能够覆盖 document cshtml 文件的网站 以便我可以根据用户的当前位置应用特定的 CSS 类 我尝试使用 URL 替代方案 例如 文档 cshtml 文档 url AreaA cshtml 文档 url A
  • Osmdroid - 更改本地地图文件夹

    我使用的是离线版本的osmdroid 地图放置在sdcard osmdroid中 你知道如何更改文件路径吗 我一直在搜索他们处理 ZIP 文件的代码 但没有找到任何解决方案 以前有人遇到过这个问题吗 Thx 如果你下载了osmdroid的代
  • 计算最接近的首选十进制结果的双精度值

    设 N x 为有效位数最少的十进制数的值 使得 x 是double最接近数字值的值 Given double值 a 和 b 我们如何计算double最接近 N b N a 的值 E g If a and b are the double v
  • Meteor Up 部署失败,但应用程序在开发中运行良好

    我正在测试我的第一个 Meteor 应用程序的部署 并考虑使用 Meteor Up 经过一番折腾后 我设法跑了mup setup没有任何错误 一切看起来都很好 然而 运行mup deploy fails mup deploy Buildin
  • 如何将句子中带括号的单词大写[重复]

    这个问题在这里已经有答案了 我使用以下代码将句子中的每个单词大写 但无法将带有括号的单词大写 PHP代码
  • 如何在基本 R 图中仅将一个图例名称设为斜体?

    我想在情节中添加一个图例 其中只有一个图例名称是斜体的 我在用着plot 在基础 R 中 但是 我需要斜体行来包含变量数字 所以我使用bquote 我尝试过的方法不起作用 a lt 2 b lt 5 plot a b l1 lt bquot
  • 如何在Sqlite中将列值转换为行?

    我在 sqlite 中有一个表值 例如 我在这样的 sqlite 查询中获得价值 select from tablename where id 101 and id 102 and id 1 and id 18 101 Local Loca
  • 如何从给定模型获取 Graph(或 GraphDef)?

    我有一个使用 Tensorflow 2 和 Keras 定义的大模型 该模型在 Python 中运行良好 现在 我想将它导入到 C 项目中 在我的 C 项目中 我使用TF GraphImportGraphDef功能 如果我准备的话效果很好