TensorRT(11):python版本序列化保存与加载模型

2023-11-19


TensorRT系列传送门(不定期更新): 深度框架|TensorRT



楼主曾经在TensorRT(7):python版本使用入门一文中简要记录了python版本是序列化与反序列化加载模型的步骤,但因为环境以及TRT版本不同,API也有相当大的变化,这里重新记录下,在windows下,tensorrt8.2.3.0版本下,调用python的API是如何加载模型的。

实验案例:采用 yolov5的onnx模型,进行FP16量化保存模型。
代码案例均来自 TensorRT提供的sample中。
详细可见TensorRT-8.2.3.0\samples\python
在这里插入图片描述

一、序列化保存模型

与C++端序列化保存模型的步骤类似

  • 1、首先定义个log 文件,然后创建一个runtime
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(TRT_LOGGER)
  • 2、建立builder,设置maxBatchSize参数
builder = trt.Builder(TRT_LOGGER)  # 创建一个builder
builder.max_batch_size = 1
  • 3、配置config,如设置fp16等
config = builder.create_builder_config()  # 创建一个congig
config.max_workspace_size = 1 << 20
config.set_flag(trt.BuilderFlag.FP16)
  • 4、解析onnx文件,并通过config序列化生成一个network
network = builder.create_network(EXPLICIT_BATCH)  # 创建一个network
parser = trt.OnnxParser(network, TRT_LOGGER)

model = open(onnx_file_path, 'rb')
if not parser.parse(model.read()):
    for error in range(parser.num_errors):
        print(parser.get_error(error))

network.get_input(0).shape = [1, 3, 640, 640]
print('Completed parsing of ONNX file')
print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
plan = builder.build_serialized_network(network, config)
with open(engine_file_path, "wb") as f:
      f.write(plan)
      print("Completed write Engine")

二、反序列化加载模型

在一中序列化建立好network后,可以调用deserialize_cuda_engine反序列化生成一个 engine

engine = runtime.deserialize_cuda_engine(plan)
print("Completed creating Engine")

如果加载保存在本地的trt模型,可以直接加载engine

 if os.path.exists(engine_file_path):
      # If a serialized engine exists, use it instead of building an engine.
      print("Reading engine from file {}".format(engine_file_path))
      with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
          return runtime.deserialize_cuda_engine(f.read())

三、完整代码

完整代码都可在github上的官网samples查询。
onnx_to_tensorrt.py


def get_engine(onnx_file_path, engine_file_path=""):
    """Attempts to load a serialized engine if available, otherwise builds a new TensorRT engine and saves it."""
    def build_engine():
        """Takes an ONNX file and creates a TensorRT engine to run inference with"""
        with trt.Builder(TRT_LOGGER) as builder, builder.create_network(common.EXPLICIT_BATCH) as network, builder.create_builder_config() as config, trt.OnnxParser(network, TRT_LOGGER) as parser, trt.Runtime(TRT_LOGGER) as runtime:
            config.max_workspace_size = 1 << 28 # 256MiB
            builder.max_batch_size = 1
            # Parse model file
            if not os.path.exists(onnx_file_path):
                print('ONNX file {} not found, please run yolov3_to_onnx.py first to generate it.'.format(onnx_file_path))
                exit(0)
            print('Loading ONNX file from path {}...'.format(onnx_file_path))
            with open(onnx_file_path, 'rb') as model:
                print('Beginning ONNX file parsing')
                if not parser.parse(model.read()):
                    print ('ERROR: Failed to parse the ONNX file.')
                    for error in range(parser.num_errors):
                        print (parser.get_error(error))
                    return None
            # The actual yolov3.onnx is generated with batch size 64. Reshape input to batch size 1
            network.get_input(0).shape = [1, 3, 608, 608]
            print('Completed parsing of ONNX file')
            print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
            plan = builder.build_serialized_network(network, config)
            engine = runtime.deserialize_cuda_engine(plan)
            print("Completed creating Engine")
            with open(engine_file_path, "wb") as f:
                f.write(plan)
            return engine

    if os.path.exists(engine_file_path):
        # If a serialized engine exists, use it instead of building an engine.
        print("Reading engine from file {}".format(engine_file_path))
        with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            return runtime.deserialize_cuda_engine(f.read())
    else:
        return build_engine()
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TensorRT(11):python版本序列化保存与加载模型 的相关文章

随机推荐

  • JAVA发展历程

    Java是一门面向对象的编程语言 不仅吸收了C 语言的各种优点 还摒弃了C 里难以理解的多继承 指针等概念 因此Java语言具有功能强大和简单易用两个特征 Java语言作为静态面向对象编程语言的代表 极好地实现了面向对象理论 允许程序员以优
  • c语言之数据结构学习心得

    写在前面 你们好 我是小庄 很高兴能和你们一起学习c语言 如果您对编程感兴趣的话可关注我的动态 写博文是一种习惯 在这过程中能够梳理知识和巩固知识点 一 绪论 1 什么是数据 数据元素 数据项 数据对象 数据结构 1 数据 客观事物的符号表
  • 在eclipse里建立包中包

    工具 原料 工具软件 j2EE eclipse 语言 Java 方法 步骤 1 在src文件夹右击 new package 见下图 2 点击finish 3 在com包右击new package 4
  • 工控上位机程序为什么只能用C语言?

    工控上位机程序并不只能用C 开发 实际上在工业自动化领域中 常见的上位机开发语言包括但不限于以下几种 C C 是一种常用的编程语言 在工控领域中被广泛使用 它具有良好的面向对象特性和丰富的类库支持 可以实现高性能的上位机程序开发 C C C
  • Allegro使用经验笔记

    一 安装 SPB15 2 CD1 3 安装1 2 第3为库 不安装 License安装 设置环境变量Lm license file D Cadencelicense Dat 修改License中SERVER Yyh ANY 5280为SER
  • Typora快捷键大全

    1 字体编辑 1 1 大小 大小 ctr 数字 或 ctr 加减号 或 1 2 加粗 加粗 ctr b 1 3 倾斜 倾斜 ctr i 1 4 下划线 下划线 ctr u 1 5 删除线 删除线 alt shift 5 1 6 上标 上标
  • YOLOv8改进开源

    大致介绍一下AI全栈技术社区的相关内容 主要涵盖了YOLO全系列模型的改进 量化 蒸馏 剪枝以及不同工具链的使用 同时也涵盖多目标跟踪 语义分割 3D目标检测 AI模型部署等内容 具体内容小伙伴们可以参考下面的目录部分 所有内容均有答疑服务
  • 学习-Python字符串之格式化

    第1关 学习 Python字符串之格式化 任务描述 本关任务 给定一个列表 计算列表内所有数据标准差 结果保留小数点后 2 位 相关知识 为了完成本关任务 你需要掌握 的使用 format 的使用 Template 的使用 在之前的实训中
  • 数据分析01——Anaconda安装/Anaconda中的pip换源/jupyter配置

    0 前言 数据分析三大模块知识 numpy 数组计算 pandas 基于numpy开发 用于数据清洗和数据分析 matplotlib 实现数据可视化 1 Anaconda安装 安装Anaconda 注意安装路径不一定是c盘 但是安装目录不要
  • Python Excel操作模块XlsxWriter之写入worksheet.write()

    worksheet write wirte row col args 向工作表单元格写入普通的数据 参数 row 单元格所在的行 索引从0开始计数 col 单元格所在的列 索引从0开始计数 args 传递到子方法的附加参数诸如数字 字符串
  • 端口介绍

    文章来源 https m toutiaocdn com group 6680437870504706572 app news article timestamp 1563010542 req id 201907131735410100230
  • Linux lvm管理讲解及命令

    作者 小刘在C站 个人主页 小刘主页 每天分享云计算网络运维课堂笔记 努力不一定有回报 但一定会有收获加油 一起努力 共赴美好人生 夕阳下 是最美的绽放 树高千尺 落叶归根人生不易 人间真情 前言 目录 一 lvm管理 1 Logical
  • mysql sql优化方法_一个MySql Sql 优化技巧分享

    有天发现一个带inner join的sql 执行速度虽然不是很慢 0 1 0 2 但是没有达到理想速度 两个表关联 且关联的字段都是主键 查询的字段是唯一索引 sql如下 SELECTp item token p item product
  • 如何在小程序实现人脸识别的方法

    1 获取用户授权 在小程序中实现人脸识别需要先获取用户的授权 用户需要允许小程序访问他们的摄像头和图像数据 这样才能进行人脸识别 2 采集图像数据 在获得用户授权后 小程序可以通过摄像头或者相册功能 采集用户的面部图像数据 3 使用图像处理
  • Java-private构造方法

    private 构造函数一般用于Singleton模式 指的是整个应用只有本类的一个对象 一般这种类都有一个类似getInstance 的方法 class A public String name 构造函数限定为private 不可以直接创
  • 标准差(Standard Deviation), 标准误差(Standard error),变异系数 (Coefficient of Variance )的区别与联系

    标准差 Standard Deviation 中文环境中又常称均方差 是离均差平方的算术平均数的平方根 用 表示 标准差是方差的算术平方根 标准差能反映一个数据集的离散程度 平均数相同的两组数据 标准差未必相同 标准误差 Standard
  • 学习Flask之Flask-Login 用户会话管理

    Flask Login 用户控制用户会话管理 简单点说 就是控制登录 如果是自己写的登录系统 一般都是通过操作session 然后后台根据session 来判断权限 Flask Login 就是负责这部分 直接开始 安装 pip insta
  • System.Data.OracleClient 需要 Oracle 客户端软件 version 8.1.7 或更高版本

    同学的电脑连接实验室的服务器时出现 System Data OracleClient 需要 Oracle 客户端软件 version 8 1 7 或更高版本 而我自己的电脑可以轻松连接服务器的数据库 首先 实验室用的是Oracle 12c
  • 力扣 942. 增减字符串匹配 双指针解法C++

    给定只含 I 增大 或 D 减小 的字符串 S 令 N S length 返回 0 1 N 的任意排列 A 使得对于所有 i 0 N 1 都有 如果 S i I 那么 A i lt A i 1 如果 S i D 那么 A i gt A i
  • TensorRT(11):python版本序列化保存与加载模型

    TensorRT系列传送门 不定期更新 深度框架 TensorRT 文章目录 一 序列化保存模型 二 反序列化加载模型 三 完整代码 楼主曾经在TensorRT 7 python版本使用入门一文中简要记录了python版本是序列化与反序列化