【PaddlePaddle onnx】PaddlePaddle导出ONNX及模型可视化教程

2023-05-16

文章目录

  • 1 背景介绍
  • 2 实验环境
  • 3 paddle.onnx.export函数简介
  • 4 代码实操
    • 4.1 PaddlePaddle与ONNX模型导出
    • 4.2 ONNX正确性验证
    • 4.3 PaddlePaddle与ONNX的一致性检查
    • 4.4 多输入的情况
  • 5 ONNX模型可视化
  • 6 ir_version和opset_version修改
  • 7 致谢

原文来自于地平线开发者社区,未来会持续发布深度学习、板端部署的相关优质文章与视频,如果文章对您有帮助,麻烦给点个赞,如果您有兴趣一起学习,欢迎点个关注:寻找永不遗憾(CSDN用户名)

1 背景介绍

使用深度学习开源框架Pytorch训练完网络模型后,在部署之前通常需要进行格式转换,地平线工具链模型转换目前支持Caffe1.0和ONNX(opset_version=10/11 且 ir_version≤7)两种。ONNX(Open Neural Network Exchange)格式是一种常用的开源神经网络格式,被较多推理引擎支持,例如Pytorch、PaddlePaddle、TensorFlow等。本文将详细介绍如何将PaddlePaddle格式的模型导出到ONNX格式。

2 实验环境

本教程的实验环境如下:

Python库Version
paddlepaddle2.4.1
paddle2onnx1.0.5
onnx1.13.0
onnxruntime1.14.0

3 paddle.onnx.export函数简介

paddle.onnx.export函数可以将PaddlePaddle模型导出为ONNX模型,函数介绍如下,其中x_spec用于配置paddle.onnx.export的input_spec参数。

x_spec = paddle.static.InputSpec(shape=None, dtype='float32', name=None)
#shape:   声明维度信息,默认为 None
#dtype:   数据类型,默认为 float32
#name:    网络输入节点名称

paddle.onnx.export(layer, path, input_spec=[x_spec], opset_version=11, **configs)
#layer:          导出的Layer对象,即需要转换的网络模型
#path:           存储模型的路径前缀,导出后会自动添加后缀“.onnx”
#input_spec:     用于配置模型输入属性
#opset_version:  默认为9,请手动配置10或11

关于paddle.onnx.export的更多详细介绍,可以查阅PaddlePaddle的API文档:
https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/onnx/export_cn.html

4 代码实操

4.1 PaddlePaddle与ONNX模型导出

以下代码展示了搭建一个简单分类模型并以PaddlePaddle和ONNX格式保存的过程。

import paddle
import paddle.nn as nn

class MyNet(nn.Layer):
    def __init__(self, num_classes=10):
        super(MyNet, self).__init__()
        self.num_classes = num_classes
        self.features = nn.Sequential(
            nn.Conv2D(in_channels=1, out_channels=2,
                      kernel_size=3, stride=1, padding=1),
            nn.ReLU())
        self.linear = nn.Sequential(nn.Linear(98, num_classes))
    def forward(self, inputs):
        x = self.features(inputs)
        x = paddle.flatten(x, 1)
        x = self.linear(x)
        return x

model = MyNet()

#准备输入数据
x_spec = paddle.static.InputSpec([1, 1, 7, 7], 'float32', 'input1')
#将模型以PaddlePaddle的格式保存,以验证和ONNX模型推理的一致性
paddle.jit.save(layer=model, path='./pd_model/pdmodel',
                input_spec=[x_spec])
#将模型导出为ONNX格式保存
paddle.onnx.export(layer=model, path='./model',
                   input_spec=[x_spec], opset_version=11)

4.2 ONNX正确性验证

可以用以下代码验证ONNX模型的正确性,会检查模型的版本,图的结构,节点及输入输出。若输出为 Check: None 则表示无报错信息,模型导出正确。

import onnx

onnx_model = onnx.load("./model.onnx")
check = onnx.checker.check_model(onnx_model)
print('Check: ', check)

4.3 PaddlePaddle与ONNX的一致性检查

可以使用以下代码检查导出的ONNX模型和原始的PaddlePaddle模型是否有相同的计算结果。

import numpy as np
import onnxruntime
import paddle

input1 = np.random.random((1, 1, 7, 7)).astype('float32')

ort_sess = onnxruntime.InferenceSession("./model.onnx")
ort_inputs = {ort_sess.get_inputs()[0].name: input1}
ort_outs = ort_sess.run(None, ort_inputs)

model = paddle.jit.load("./pd_model/pdmodel")
model.eval()
paddle_input = paddle.to_tensor(input1)
paddle_outs = model(paddle_input)

print(ort_outs[0])
print(paddle_outs.numpy())
np.testing.assert_allclose(tf_outs.numpy(), ort_outs[0], rtol=1e-03, atol=1e-05)
print("onnx model check finsh.")

4.4 多输入的情况

若您的模型存在多输入,则可参考下方代码保存成PaddlePaddle和ONNX格式。ONNX的正确性验证和PaddlePaddle与ONNX的一致性检查不再赘述,仿照上述代码编写即可。

import paddle
import paddle.nn as nn

class MyNet(nn.Layer):
    def __init__(self, num_classes=10):
        super(MyNet, self).__init__()
        self.num_classes = num_classes
        self.features_1 = nn.Sequential(
            nn.Conv2D(in_channels=1, out_channels=2,
                      kernel_size=3, stride=1, padding=1),
            nn.ReLU())
        self.features_2 = nn.Sequential(
        nn.Conv2D(in_channels=1, out_channels=2,
                  kernel_size=3, stride=1, padding=1),
        nn.ReLU())
        self.linear = nn.Sequential(nn.Linear(98, num_classes))
    def forward(self, inputs1, inputs2):
        x = self.features_1(inputs1)
        y = self.features_2(inputs2)
        z = paddle.concat((x, y), 1)
        z = paddle.flatten(z, 1)
        z = self.linear(z)
        return z

model = MyNet()

x_spec = paddle.static.InputSpec([1, 1, 7, 7], 'float32', 'input1')
y_spec = paddle.static.InputSpec([1, 1, 7, 7], 'float32', 'input2')
paddle.jit.save(layer=model, path='./pd_model/pdmodel',
                input_spec=[x_spec, y_spec])
paddle.onnx.export(layer=model, path='./model',
                   input_spec=[x_spec, y_spec], opset_version=11)

5 ONNX模型可视化

导出成ONNX模型后,可以使用开源可视化工具Netron来查看网络结构及相关配置信息。Netron的使用方式主要分为两种,一种是使用在线网页版 https://netron.app/ ,另一种是下载安装程序 https://github.com/lutzroeder/netron 。此教程中模型的可视化效果为:

6 ir_version和opset_version修改

地平线工具链支持的ONNX模型需要满足 opset_version=10/11 且 ir_version≤7,当拿到的ONNX模型不满足这两个要求时,可以修改代码重新导出,或者尝试编写脚本直接修改ONNX模型的对应属性,第二种方式的示例代码如下:

import onnx

model = onnx.load("./model.onnx")
model.ir_version = 6
model.opset_import[0].version = 10
onnx.save_model(model, "./model_version.onnx")

**注意:**高版本向低版本切换时可能会出现问题,这里只是一种可尝试的解决方案。
调整结束后,使用Netron可视化model_version.onnx,如下图所示:
在这里插入图片描述

此时,ONNX模型的ir_version=6,opset_version=10,满足地平线工具链的转换条件。

7 致谢

原文来自于地平线开发者社区,未来会持续发布深度学习、板端部署的相关优质文章与视频,如果文章对您有帮助,麻烦给点个赞,如果您有兴趣一起学习,欢迎点个关注:寻找永不遗憾(CSDN用户名)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【PaddlePaddle onnx】PaddlePaddle导出ONNX及模型可视化教程 的相关文章

  • 解压码

    BN00001 22kke BN00002 88cde BN00003 00ike BN00004 76cdb BN00005 09dbm BN00006 0mndc BN00007 cd78d BN00008 bdmf8 BN00009
  • 保险项目业务流程

    1 整个项目分为四分模块 xff1a 信息采集模块 信息验证 审批 生成合同 xff08 开单 xff09 信息采集模块 xff1a 包括购买保险产品 xff0c 客户个人信息 1 纸质文档给客户填写 xff0c 在回来录入系统 2 客户直
  • IDEA使用maven自定义archetype

    标题自定义archetype 在pom文件中添加archetype plugin span class token generics span class token punctuation lt span plugin span clas
  • 自定义Perperties文件内容读取

    新建properties文件放在resources目录下 properties文件内容 url span class token operator 61 span jdbc span class token operator span my
  • 使用CSS中的Hover控制显示子元素或者兄弟元素

    lt DOCTYPE html gt lt html lang 61 34 en 34 gt lt head gt lt meta charset 61 34 UTF 8 34 gt lt meta name 61 34 viewport
  • iphone表情显示问号_如何在iPhone上搜索特定的表情符号

    iphone表情显示问号 Most of us use emoji on our iPhone but until recently finding the right one has been tricky Luckily startin
  • maven项目中的jdbc连接步骤

    在maven项目pom xml中到入驱动包 xff08 以下是驱动包代码 xff09 lt dependencies gt lt https mvnrepository com artifact mysql mysql connector
  • executeUpdate()与executeQuery()的使用

    增 删 改 用executeUpdate xff08 xff09 返回值为int型 xff0c 表示被影响的行数 例子 查用executeQuery 返回的是一个集合 next xff08 xff09 表示 指针先下一行 xff0c 还有f
  • Access denied for user ''@'localhost' (using password: YES)错误解决方法

    远程登录被拒绝 xff0c 要改一个表数据的属性让他可以远程登录 解决方法如下 xff0c 执行命令 xff1a mysql gt use mysql mysql gt select host user from user 查看结果是不是r
  • leetcode部分数据库+sqlzoo练习题

    175 组合两个表 SQL架构 表1 Person 43 43 43 列名 类型 43 43 43 PersonId int FirstName varchar LastName varchar 43 43 43 PersonId 是上表主
  • ubuntu下手动安装gnome插件

    ubuntu下手动安装gnome插件 span class token comment 下载环境 span sudo apt span class token operator span span class token keyword g
  • 类和对象的理解

    类和对象的关系 是java中两个重要的概念 xff0c 简单一句话将就是 xff1a 类是对象的模板 xff0c 对象是类的实例 比如 xff1a 设计车的图纸是类 xff0c 然后比亚迪 本田 奔驰这些车 xff08 对象 xff09 都
  • java设计模式的几种体现方式

    1 单例模型 有时候在我的设计中 xff0c 所有的类只共享一个实例 xff0c 那么这时候就需要设计一个单实例的类 思路是将这个类构造器私有化 xff0c 这样外部就无法直接创建对象 xff0c 然后提供公有的静态方法 xff0c 让外部
  • springIOC使用xml装配JavaBean对象

    在一个maven工程下 xff0c 在pom xml中导入spring依赖和相关的配置 lt xml version 61 34 1 0 34 encoding 61 34 UTF 8 34 gt lt project xmlns 61 3
  • spring整合MyBatis代码

    Spring 整合 MyBatis 就是把Spring和MyBatis应用到同一个项目中 xff1b 其中MyBatis提供数据库相关的操作 xff0c 完成对象数据和关系数据的转换 xff1b Spring完成项目的管理 xff0c 通过
  • Servlet基础知识

    web应用程序的组成 xff1a 网页 xff1a 浏览器需要显示的内容 Web浏览器 xff1a 1 向Web服务器发出请求 2 解析网页 xff0c 渲染显示给用户 Web服务器 xff1a 1 提供Web服务 2 存放Web应用程序
  • 兆位和兆字节之间有什么区别?

    majcot Shutterstock 马约科特 快门 Despite the fact that they re similar words with similar abbreviations megabits Mb and megab
  • SSM(Spring + SpringMVC + MyBatis)环境搭建

    1 导入依赖 lt Spring上下文容器 gt lt dependency gt lt groupId gt org springframework lt groupId gt lt artifactId gt spring contex
  • 系统安全复习

    DOS DOS xff1a 拒绝服务攻击 xff0c 向目标主机某端口发送超过处理能力的数据包 xff0c 耗尽目标主机资源 xff0c 使其无定法响应正常的服务请求 xff0c 使目标系统停止响应甚至奔溃 DDOS DDOS xff1a
  • 物联网四层架构

    1 感知层 2 网络层 3 应用层 4 公共技术

随机推荐