将onnx的静态batch改为动态batch及修改输入输出层的名称

2023-11-19

背景

在模型的部署中,为了高效利用硬件算力,常常会需要将多个输入组成一个batch同时输入网络进行推理,这个batch的大小根据系统的负载或者摄像头的路数时刻在变化,因此网络的输入batch是在动态变化的。对于pytorch等框架来说,我们并不会感受到这个问题,因为整个网络在pytorch中都是动态的。而在实际的工程化部署中,为了运行效率,却并不能有这样的灵活性。可能会有人说,那我就把batch固定在一个最大值,然后输入实际的batch,这样实际上网络是以最大batch在推理的,浪费了算力。所以我们需要能支持动态的batch,能够根据输入的batch数来运行。

一个常见的训练到部署的路径是:pytorch→onnx→tensorrt。在pytorch导出onnx时,我们可以指定输出为动态的输入:

torch_out = torch.onnx.export(model, inp,
                              save_path,input_names=["data"],output_names=["fc1"],dynamic_axes={
        "data":{0:'batch_size'},"fc1":{0:'batch_size'}
    })

而另一些时候,我们部署的模型来源于他人或开源模型,已经失去了原始的pytorch模型,此时如果onnx是静态batch的,在移植到tensorrt时,其输入就为静态输入了。想要动态输入,就需要对onnx模型本身进行修改了。另一方面,算法工程师在导模型的时候,如果没有指定输入层输出层的名称,导出的模型的层名有时候可读性比较差,比如输出是batchnorm_274这类名称,为了方便维护,也有需要对onnx的输入输出层名称进行修改。

操作

修改输入输出层

def change_input_output_dim(model):
    # Use some symbolic name not used for any other dimension
    sym_batch_dim = "batch"

    # The following code changes the first dimension of every input to be batch-dim
    # Modify as appropriate ... note that this requires all inputs to
    # have the same batch_dim 
    inputs = model.graph.input
    for input in inputs:
        # Checks omitted.This assumes that all inputs are tensors and have a shape with first dim.
        # Add checks as needed.
        dim1 = input.type.tensor_type.shape.dim[0]
        # update dim to be a symbolic value
        dim1.dim_param = sym_batch_dim
        # or update it to be an actual value:
        # dim1.dim_value = actual_batch_dim
    
    outputs = model.graph.output
    for output in outputs:
        # Checks omitted.This assumes that all inputs are tensors and have a shape with first dim.
        # Add checks as needed.
        dim1 = output.type.tensor_type.shape.dim[0]
        # update dim to be a symbolic value
        dim1.dim_param = sym_batch_dim

model = onnx.load(onnx_path)
change_input_output_dim(model)

通过将输入层和输出层的shape的第一维修改为非数字,就可以将onnx模型改为动态batch。

修改输入输出层名称

def change_input_node_name(model, input_names):
    for i,input in enumerate(model.graph.input):
        input_name = input_names[i]
        for node in model.graph.node:
            for i, name in enumerate(node.input):
                if name == input.name:
                    node.input[i] = input_name
        input.name = input_name
        

def change_output_node_name(model, output_names):
    for i,output in enumerate(model.graph.output):
        output_name = output_names[i]
        for node in model.graph.node:
            for i, name in enumerate(node.output):
                if name == output.name:
                    node.output[i] = output_name
        output.name = output_name

代码中input_names和output_names是我们希望改到的名称,做法是遍历网络,若有node的输入层名与要修改的输入层名称相同,则改成新的输入层名。输出层类似。

完整代码

import onnx
def change_input_output_dim(model):
    # Use some symbolic name not used for any other dimension
    sym_batch_dim = "batch"

    # The following code changes the first dimension of every input to be batch-dim
    # Modify as appropriate ... note that this requires all inputs to
    # have the same batch_dim 
    inputs = model.graph.input
    for input in inputs:
        # Checks omitted.This assumes that all inputs are tensors and have a shape with first dim.
        # Add checks as needed.
        dim1 = input.type.tensor_type.shape.dim[0]
        # update dim to be a symbolic value
        dim1.dim_param = sym_batch_dim
        # or update it to be an actual value:
        # dim1.dim_value = actual_batch_dim
    
    outputs = model.graph.output
    for output in outputs:
        # Checks omitted.This assumes that all inputs are tensors and have a shape with first dim.
        # Add checks as needed.
        dim1 = output.type.tensor_type.shape.dim[0]
        # update dim to be a symbolic value
        dim1.dim_param = sym_batch_dim

def change_input_node_name(model, input_names):
    for i,input in enumerate(model.graph.input):
        input_name = input_names[i]
        for node in model.graph.node:
            for i, name in enumerate(node.input):
                if name == input.name:
                    node.input[i] = input_name
        input.name = input_name
        

def change_output_node_name(model, output_names):
    for i,output in enumerate(model.graph.output):
        output_name = output_names[i]
        for node in model.graph.node:
            for i, name in enumerate(node.output):
                if name == output.name:
                    node.output[i] = output_name
        output.name = output_name


onnx_path = ""
save_path = ""
model = onnx.load(onnx_path)
change_input_output_dim(model)
change_input_node_name(model, ["data"])
change_output_node_name(model, ["fc1"])

onnx.save(model, save_path)

经过修改后的onnx模型输入输出将成为动态batch,可以方便的移植到tensorrt等框架以支持高效推理。

在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

将onnx的静态batch改为动态batch及修改输入输出层的名称 的相关文章

随机推荐

  • Java实现简单版SVM

    Java实现简单版SVM 最近的图像分类工作要用到latent svm 为了更加深入了解svm 自己动手实现一个简单版的 之所以说是简单版 因为没有用到拉格朗日 对偶 核函数等等 而是用最简单的梯度下降法求解 其中的数学原理我参考了http
  • 解决Adobe Flash Player已不再受支持的问题

    1 问题展示 我们在访问某些网站时 可能会出现Adobe Flash Player已不再受支持的问题 具体如下图所示 这会对我们的日常生活需求产生极大的不便 因此迫切需要一个能够解决此问题的方法 其实很简单 具体操作请看下面的步骤 2 下载
  • 移动端+PC端图片预览+手势缩放等功能合集

    话不多说 直接上代码 大家可按需求功能复制使用 window onload function 点击图片进入预览 var Dom document querySelector preview Dom onclick function var
  • 接口如何实现多态

    抽象类是用来继承的 不能被实例化 抽象类里可以有成员变量 接口中没有 1 抽象类里的抽象方法 只有在子类实现了才能使用 2 抽象类里的普通方法 可被子类调用 3 接口里的方法 都被默认修饰为public abstract类型 4 接口里的变
  • 记录1年免费亚马逊AWS云服务器申请方法过程及使用技巧

    转载 http www itbulu com free aws html 早年我们才开始学习网站建设的时候 会看到且也会寻找免费主机商 主要原因是那时候提供免费主机的商家还是比较靠谱的 而且那时候主机商并不是很多且成本也比较大 我们深知听到
  • linux上层app调用驱动底层的过程详解

    APP应用程序 gt 应用框架层 gt 硬件抽象层 gt 硬件驱动程序 一 硬件驱动层 进入kernel drivers文件夹中 创建一文件夹 放入驱动程序 包括头文件 C文件 Makefile Kconfig 同时对drivers下的Ma
  • 滚动屏幕或缩放屏幕,使用节流

    场景 滚动屏幕 onScroll 缩放屏幕 resize 如果需要统计用户滚动屏幕或缩放屏幕的行为作出相应的网页反应 容易导致网络的阻塞 mounted window addEventListener resize this throttl
  • 基于聚类的异常值检测算法依据及python实现

    假设数据集D被聚类算法划分到k个类C C1 C2 CK 对象p 的离群因子of3 p 定义为 与所有类间距的加权平均值 其中 D 为样本数量 Cj 为第j个聚类群体样本数量 d p cj 为样本p与第j个聚类中心的距离 其中cj表示第j个聚
  • LeetCode·每日一题·1851. 包含每个查询的最小区间·优先队列(小顶堆)

    题目 示例 思路 离线查询 输入的结果数组queries 是无序的 如果我们按照输入的queries 本身的顺序逐个查看 时间复杂度会比较高 于是 我们将queries 数组按照数值大小 由小到大逐个查询 这种方法称之为离线查询 位运算 离
  • ExtJS之 Proxy数据代理

    ExtJS之 Proxy数据代理 代理种类截图 ExtJS提供的数据代理主要分为两大类 1 客户端代理 Ext data proxy Client 2 服务器代理 Ext data proxy Server 这两个类 都继承自 Ext da
  • Ansible的简介及部署

    1 ansible简介 1 1 什么是ansible ansible是一款开源自动化平台 是一个配置管理工具 自动化运维工具 1 2 ansible的优点 跨平台支持 人类可读自动化 ansible提供linux Windows unix和
  • 类的一些内置方法

    一 slots 用来取代 dict 优势是省内存 附加功能是只能创建slots 定义好的key 注意 不要乱用 用了就没有 dic 方法了 class Foo slots name age 这里可以是列表或者单个字符串 定义key值 f1
  • 解决matplotlib画图中文乱码

    一 下载字体 以SimHei字体为例 下载SimHei ttf文件 在python环境下输入 import matplotlib print matplotlib path 输出matplotlib的安装环境 放在该路径下的mpl data
  • 彻底删除SVN版本库某一文件夹或文件

    基础描述 要彻底删除SVN版本库某一文件夹或文件 可采取这种方法 举例说明 例 假设SVN库路径为E svn project 库中的目录结构为 Trunk Software test exe 删除Software 目录下的test exe文
  • 赫尔德不等式详细证明

    赫尔德不等式详细证明 k 1n akbk k 1n ak p 1 p k 1n bk q 1 q 1 sum k 1 n left a k b k right leqslant sum k 1 n left a k right p 1 p
  • FPGA实战--等精度频率测量

    首先放置效果图 本次试验中采用的是等精度测频率 等精度测频的原理是产生一个1s的高电平 在高电平中对被测方波进行计数 所测得数字即该波形频率 具体等精度测量原理请参考 http www elecfans com d 591858 html
  • 若依框架前后端分离版本自动生成代码的详细步骤

    1 若依框架的下载和本地运行这里就不介绍了主要讲代码自动生成 只是单表的增删改成 复杂的多表业务逻辑还是需要自己手写的 话不多说直接上图 一 新建模块 本地运行起来后右键新建Module 注意这里的Name 可以和若依类似 也可以自己定义新
  • 基于Axure 8课程设计-前端页面设计-漫画APP界面/UI设计(免费分享.rp文件学习)

    这次的课程设计主要是UI设计 基于Axure我设计了一个类似动漫之家的一个设计界面 以下是效果图 UI首页 点击夜魔侠 点击更新按钮 点击分类按钮 排行 点击专题按钮之后 点击搜索按钮之后 点击 新闻 按钮 点击 轻小说 按钮 点击 我的
  • 关于Qt/C++和QML获取屏幕大小方法的总结

    在桌面应用程序的开发过程中 获取屏幕 桌面 的大小来定位桌面应用所显示的位置 是桌面开发中经常用到的 手段 在Qt开发和QML开发中也不例外 本篇着重介绍Qt获取桌面屏幕大小的两种方法 以及对应的QML中获取桌面屏幕 大小的两种方法 首先上
  • 将onnx的静态batch改为动态batch及修改输入输出层的名称

    文章目录 背景 操作 修改输入输出层 修改输入输出层名称 完整代码 背景 在模型的部署中 为了高效利用硬件算力 常常会需要将多个输入组成一个batch同时输入网络进行推理 这个batch的大小根据系统的负载或者摄像头的路数时刻在变化 因此网