使用tensorrt加速深度学习模型推断

2023-12-05

此博客介绍如何将resnet101模型在CIFAR100数据集的分类任务,使用tensorrt部署。

完整代码如下

1.import以及数据加载、构建engine函数

import argparse
import os

import torch
from torch.utils.data import DataLoader
import torchvision
import torchvision.models as models

import time


import numpy as np
import tensorrt as trt
import common
import torchvision.transforms as transforms

TRT_LOGGER = trt.Logger()
os.environ["CUDA_VISIBLE_DEVICES"] = '0'  # 指定0号GPU可用


# mean and std of cifar100 dataset
CIFAR100_TRAIN_MEAN = (
    0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401,
                      0.2564384629170883, 0.27615047132568404)

def get_test_dataloader(mean, std, batch_size=16, num_workers=2, shuffle=True):
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    cifar100_test = torchvision.datasets.CIFAR100(
        root='./data', train=False, download=True, transform=transform_test)
    cifar100_test_loader = DataLoader(
        cifar100_test, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)

    return cifar100_test_loader


def ONNX_build_engine(onnx_file_path, trt_file):
    G_LOGGER = trt.Logger(trt.Logger.WARNING)
    explicit_batch = 1 << (int)(
        trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    batch_size = 64  
    with trt.Builder(G_LOGGER) as builder, builder.create_network(explicit_batch) as network, \
            trt.OnnxParser(network, G_LOGGER) as parser:
        builder.max_batch_size = batch_size
        config = builder.create_builder_config()
        config.set_memory_pool_limit(
            trt.MemoryPoolType.WORKSPACE, common.GiB(1))
        config.set_flag(trt.BuilderFlag.FP16)
        print('Loading ONNX file from path {}...'.format(onnx_file_path))
        with open(onnx_file_path, 'rb') as model:
            print('Beginning ONNX file parsing')
            parser.parse(model.read())
        print('Completed parsing of ONNX file')
        print('Building an engine from file {}; this may take a while...'.format(
            onnx_file_path))

        profile = builder.create_optimization_profile()
        profile.set_shape("input", (1, 3, 32, 32),
                          (1, 3, 32, 32), (batch_size, 3, 32, 32))
        config.add_optimization_profile(profile)
        engine = builder.build_serialized_network(network, config)
        print("Completed creating Engine")
        with open(trt_file, "wb") as f:
            f.write(engine)
        return engine

2.导入官方模型及CIFAR100数据集


if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    parser.add_argument('-gpu', action='store_true',
                        default=True, help='use gpu or not')
    parser.add_argument('-b', type=int, default=32,
                        help='batch size for dataloader')
    args = parser.parse_args()
    print(args)

    cifar100_test_loader = get_test_dataloader(
        CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD,
        num_workers=1,
        batch_size=args.b)


    device = "cuda" if args.gpu else "cpu"
    net = models.resnet101(pretrained=True)
    net = net.to(device)
    # # print(net)
    net.eval()

3.不采用tensort的推断时间

#%%
    t1 = time.time()
    for n_iter, (image, label) in enumerate(cifar100_test_loader):
        pred = net(image.to(device))
        # print(pred.shape)
    t2 = time.time()
    print(t2-t1)

耗时约为8~9s。

4.采用tensort加速—使用tensorrt 库

4.1 导出onnx模型

#%% save onnx 
    input = torch.rand([1, 3, 32, 32]).to(device)
    onnx_file = "resnet101.onnx"

    if  os.path.exists(onnx_file):
        os.remove(onnx_file)
    torch.onnx.export(net, input, onnx_file,
                      input_names=['input'],  # the model's input names
                      output_names=['output'],
                      dynamic_axes={'input': {0: 'batch_size'},
                                    'output': {0: 'batch_size'}},
                      # opset_version=12,
                      )
    print("onnx file generated!")

4.2 生成tensorrt engine 文件

# %%generate tensorrt engine file
    trt_file = "resnet101.trt"

    ONNX_build_engine(onnx_file, trt_file)
    print("trt file generated!")

4.3 deserialize

    trt_file = "resnet101.trt"
    runtime = trt.Runtime(TRT_LOGGER)
    with open(trt_file, 'rb') as f:
        engine = runtime.deserialize_cuda_engine(f.read())
        print("Completed creating Engine")
    context = engine.create_execution_context()
    context.set_binding_shape(0, (16, 3, 32, 32))

    inputs, outputs, bindings, stream = common.allocate_buffers(engine, 32)

4.4 推断

    t1 = time.time()
    label_ls = []
    pred_ls = []
    for n_iter, (image, label) in enumerate(cifar100_test_loader):
        # print("iteration: {}\ttotal {} iterations".format(n_iter + 1, len(cifar100_test_loader)))
        # print(image)
        inputs[0].host = image.numpy()

        trt_outputs = common.do_inference(
            context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream, batch_size=32)
        label_ls.extend(label.numpy())
        pred_ls.extend(np.array(trt_outputs[0]).reshape(
            [-1, 100]).argmax(1).tolist())
        # print((np.array(pred_ls)[:10000]==np.array(label_ls)[:10000]).sum())
    t2 = time.time()
    print(t2-t1)

耗时约为4.3s,是用我的笔记本 上的GPU RTX 3050可以实现两倍左右的加速。

5.采用tensort加速—使用torch2trt库

nvidia还有torch2trt Python包,可用于一键tensorrt加速。

其安装可参考 https://github.com/NVIDIA-AI-IOT/torch2trt .

git clone https://github.com/NVIDIA-AI-IOT/torch2trt
cd torch2trt
python setup.py install

torch2trt的使用可参考 github torch2trt

    from torch2trt import torch2trt
    inputs = torch.rand([1, 3, 32, 32]).to(device)
    model_trt = torch2trt(net, [inputs], fp16_mode=True)

    t1 = time.time()
    label_ls = []
    pred_ls = []
    for n_iter, (image, label) in enumerate(cifar100_test_loader):

        output_trt = model_trt(image.to(device))

    t2 = time.time()
    print(t2-t1)

使用起来不要太easy!

完整代码可参考 https://github.com/L0-zhang/tentorrt_demo/tree/main

参考文献

[1] csdn pytorch TensorRT 官方例子
[2] https://github.com/NVIDIA-AI-IOT/torch2trt
[3] https://github.com/L0-zhang/tentorrt_demo/tree/main

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用tensorrt加速深度学习模型推断 的相关文章

  • torch.unique() 中的参数“dim”如何工作?

    我试图提取矩阵每一行中的唯一值并将它们返回到同一个矩阵中 重复值设置为 0 例如 我想转换 torch Tensor 1 2 3 4 3 3 4 1 6 3 5 3 5 4 to torch Tensor 1 2 3 4 0 0 0 1 6
  • 检查 PyTorch 张量在 epsilon 内是否相等

    如何检查两个 PyTorch 张量在语义上是否相等 考虑到浮点错误 我想知道元素是否仅相差一个小的 epsilon 值 在撰写本文时 这是最新稳定版本 0 4 1 中的一个未记录的函数 但文档位于master unstable branch
  • PyTorch 中的截断反向传播(代码检查)

    我正在尝试在 PyTorch 中实现随时间截断的反向传播 对于以下简单情况K1 K2 我下面有一个实现可以产生合理的输出 但我只是想确保它是正确的 当我在网上查找 TBTT 的 PyTorch 示例时 它们在分离隐藏状态 将梯度归零以及这些
  • 无法将 cuda:0 设备类型张量转换为 numpy。首先使用 Tensor.cpu() 将张量复制到主机内存

    我试图展示 GAN 网络在某些指定时期的结果 打印当前结果的功能之前是在 TF 中使用的 我需要换成pytorch def show result G net z num epoch show False save False path r
  • RuntimeError:维度指定为 0 但张量没有维度

    我试图使用 MNIST 数据集实现简单的 NN 但我不断收到此错误 将 matplotlib pyplot 导入为 plt import torch from torchvision import models from torchvisi
  • 预训练 Transformer 模型的配置更改

    我正在尝试为重整变压器实现一个分类头 分类头工作正常 但是当我尝试更改配置参数之一 config axis pos shape 即模型的序列长度参数时 它会抛出错误 Reformer embeddings position embeddin
  • 我可以使用逻辑索引或索引列表对张量进行切片吗?

    我正在尝试使用列上的逻辑索引对 PyTorch 张量进行切片 我想要与索引向量中的 1 值相对应的列 切片和逻辑索引都是可能的 但是它们可以一起吗 如果是这样 怎么办 我的尝试不断抛出无用的错误 类型错误 使用 ByteTensor 类型的
  • 如何在 google colab 中运行 matlab .m 文件

    我目前正在尝试运行这个存储库https github com Fanziapril mvfnet https github com Fanziapril mvfnet这需要一个步骤 Run the Matlab ModelGeneratio
  • PyTorch 教程错误训练分类器

    我刚刚开始 PyTorch 教程使用 PyTorch 进行深度学习 60 分钟闪电战我应该补充一点 我之前没有编写过任何 python 但其他语言 如 Java 现在 我的代码看起来像 import torch import torchvi
  • 如何在pytorch中查看DataLoader中的数据

    我在 Github 上的示例中看到类似以下内容 如何查看该数据的类型 形状和其他属性 train data MyDataset int 1e3 length 50 train iterator DataLoader train data b
  • LSTM 错误:AttributeError:“tuple”对象没有属性“dim”

    我有以下代码 import torch import torch nn as nn model nn Sequential nn LSTM 300 300 nn Linear 300 100 nn ReLU nn Linear 300 7
  • 从打包序列中获取每个序列的最后一项

    我试图通过 GRU 放置打包和填充的序列 并检索每个序列最后一项的输出 当然我的意思不是 1项目 但实际上是最后一个 未填充的项目 我们预先知道序列的长度 因此应该很容易为每个序列提取length 1 item 我尝试了以下方法 impor
  • Blenderbot 微调

    我一直在尝试微调 HuggingFace 的对话模型 Blendebot 我已经尝试过官方拥抱脸网站上给出的传统方法 该方法要求我们使用 trainer train 方法来完成此操作 我使用 compile 方法尝试了它 我尝试过使用 Py
  • 如何使用pytorch构建多任务DNN,例如超过100个任务?

    下面是使用 pytorch 为两个回归任务构建 DNN 的示例代码 这forward函数返回两个输出 x1 x2 用于大量回归 分类任务的网络怎么样 例如 100 或 1000 个输出 对所有输出 例如 x1 x2 x100 进行硬编码绝对
  • pytorch 的 IDE 自动完成

    我正在使用 Visual Studio 代码 最近尝试了风筝 这两者似乎都没有 pytorch 的自动完成功能 这些工具可以吗 如果没有 有人可以推荐一个可以的编辑器吗 谢谢你 使用Pycharmhttps www jetbrains co
  • PyTorch 中的连接张量

    我有一个张量叫做data形状的 128 4 150 150 其中 128 是批量大小 4 是通道数 最后 2 个维度是高度和宽度 我有另一个张量叫做fake形状的 128 1 150 150 我想放弃最后一个list array从第 2 维
  • 保存具有自定义前向功能的 Bert 模型并将其置于 Huggingface 上

    我创建了自己的 BertClassifier 模型 从预训练开始 然后添加由不同层组成的我自己的分类头 微调后 我想使用 model save pretrained 保存模型 但是当我打印它并从预训练上传时 我看不到我的分类器头 代码如下
  • 如何在 PyTorch 中对子集使用不同的数据增强

    如何针对不同的情况使用不同的数据增强 转换 Subset在 PyTorch 中吗 例如 train test torch utils data random split dataset 80000 2000 train and test将具
  • 在Pytorch中计算欧几里得范数..理解和实现上的麻烦

    我见过另一个 StackOverflow 线程讨论计算欧几里德范数的各种实现 但我很难理解特定实现的原因 如何工作 该代码可以在 MMD 指标的实现中找到 https github com josipd torch two sample b
  • TensorFlow 相当于 PyTorch 的 Transforms.Normalize()

    我正在尝试推断最初在 PyTorch 中构建的 TFLite 模型 我一直在遵循PyTorch 实现 https github com leoxiaobin deep high resolution net pytorch blob 1ee

随机推荐

  • Latex正文引用图片编号,防止某张图片删除或调整导致正文序号对应错误

    一 背景 Latex真的是非常好用的论文排版工具 虽然不像word一样是 所见即所得 的可视化方式 但完全不用管格式 包括图片的排版 文字的缩进等等 这在word里调整起来真的是非常麻烦 特别是某个段落 图片修改后 又要重新调整格式 非常的
  • Ubuntu20.04安装向日葵、开机自启、解决windows系统远程黑屏(笔记)

    这里写目录标题 动机 1 Ubuntu20 04 安装向日葵 2 设置开机自启 3 解决windows不可远程的问题 4 大公告成 动机 办公室有个工作站 要比我的笔记本的CPU稍微好一点 用来跑陆面过程 我信心满满的装了个Ubuntu20
  • 什么是离岸公司?有什么作用?

    离岸公司是泛指在离岸法区内依据其离岸公司法规范成立的有限责任公司或股份有限公司 这些公司不能在注册地经营 而主要是在离岸法区以外的地方开展业务活动 离岸公司的主要特点包括高度保密性 无外汇管制和减免税务负担 离岸公司的作用主要有以下几个方面
  • 销售人员一定要知道的6种获取电话号码的方法

    对于销售来说 电话销售是必须要知道的销售方法 也是销售生涯中的必经之路 最开始我们并不清楚这么电话是从哪里来的 也不清楚是通过哪些方法渠道获取 那么今天就来分享给各位销售人员获取客户电话号码的方法 1 打印自己的名片 在工作当中少不了接触其
  • 5.【自动驾驶与机器人中的SLAM技术】2D点云的scan matching算法 和 检测退化场景的思路

    目录 1 基于优化的点到点 线的配准 2 对似然场图像进行插值 提高匹配精度 3 对二维激光点云中会对SLAM功能产生退化场景的检测 4 在诸如扫地机器人等这样基于2D激光雷达导航的机器人 如何处理悬空 低矮物体 5 也欢迎大家来我的读书号
  • 大Ⅲ周记11

    1 本周学习了mysql数据库操作的相关知识 根据课设要求完成了压降系统数据库表的设计 2 计算机网络完成了所有章节的作业 开始进入复习阶段 预计下周完成一至二章的复习作业
  • leetcode:93. 复原 IP 地址

    复原 IP 地址 中等 1 4K 相关企业 有效 IP 地址 正好由四个整数 每个整数位于 0 到 255 之间组成 且不能含有前导 0 整数之间用 分隔 例如 0 1 2 201 和 192 168 1 1 是 有效 IP 地址 但是 0
  • 最近在对接电商供应链,说说开放平台API接口

    B2B电商开放平台的设计需要从以下几面去思考 开放平台API接口 的接入 主要是从功能需求的角度 设计满足业务需求的接口及对应的字段 平台与商家之间信息的对接 对接的方法有哪些 对接过程中需要可能会遇到什么问题 同步开关及权限的设计 处理信
  • 鸿蒙4.0开发笔记之ArkTS装饰器语法基础@Prop@Link@State状态装饰器(十二)

    文章目录 一 哪些是状态装饰器 二 State Prop Link状态传递的核心规则 三 状态装饰器练习 一 哪些是状态装饰器 1 State 被装饰拥有其所属组件的状态 可以作为其子组件单向和双向同步的数据源 当其数值改变时 会引起相关组
  • Nessus简单介绍与安装

    1 Nessus简介 Nessus号称是世界上最流行的漏洞扫描程序 全世界有超过75000个组织在使用它 该工具提供完整的电脑漏洞扫描服务 并随时更新其漏洞数据库 Nessus不同于传统的漏洞扫描软件 Nessus可同时在本机或远端上遥控
  • WebGL笔记:矩阵平移的数学原理和实现

    矩阵平移的数学原理 让向量OA位移 x方向 tx y方向 ty z方向 tz 最终得到向量OB 矩阵平移的应用 再比如我要让顶点的x移动0 1 y移动0 2 z移动0 3 1 顶点着色器核心代码
  • 有效表达观点的艺术

    有效表达观点的艺术 在人际交往中 有效地表达自己的观点是建立良好关系和实现有效沟通的关键 然而 这并不总是易如反掌 有时候 我们可能会遇到表达困难 或者我们的观点可能被误解 本文将探讨如何有效地表达观点 以及掌握说话的艺术的重要性 首先 清
  • 人工智能:开启未来商业新篇章

    人工智能 开启未来商业新篇章 随着科技的快速发展 人工智能 AI 在商业领域的应用越来越广泛 成为企业把握未来商业机遇的重要方向 本文将探讨人工智能如何重塑商业格局 为企业提供新的增长点 以及企业如何抓住AI的商业契机 一 AI重塑商业格局
  • 机器人学英语

    我的prompt i want to you act as an english language teacher asistant to help me study english you could teach me in such a
  • 详解Hotspot的经典7种垃圾收集器原理特点与组合搭配

    详解Hotspot的经典7种垃圾收集器原理特点与组合搭配 HotSpot共有7种垃圾收集器 3个新生代垃圾收集器 3个老年代垃圾收集器 以及G1 一共构成7种可供选择的垃圾收集器组合 新生代与老年代垃圾收集器之间形成6种组合 每个新生代垃圾
  • 在深圳月入一万的很丢人吗

    在深圳 月入一万的收入是否丢人 这是一个很主观的问题 因为每个人的生活需求和价值观不同 从经济学的角度来看 深圳作为中国的经济特区和一线城市 其生活成本相对较高 从这个角度看 月入一万的收入在某种程度上可能不足以满足一些人的生活需求 根据最
  • 给自己泡了一壶茶

    清晨 当第一缕阳光透过窗户照亮了房间 我慵懒地爬起床 开始享受新的一天 我泡了一壶早茶 浅浅的茶香立刻弥漫在空气中 让我感到宁静而放松 我坐在窗边 静静地看着窗外的世界 清晨的街道上 行人和车辆都还不多 显得格外的宁静 微风吹过树叶 带来阵
  • 拍图识字软件哪个好用?这些好用的软件推荐给你们

    在快节奏的现代生活中 你可能会遇到需要从图片中获取文字信息的情况 无论是读书 工作还是生活中 有时候会需要从图片中提取文字 当你收到了一份手写的便签或菜单 上面的字迹可能很模糊 或者你需要在没有文字的地方快速获取信息 这时 你可能会想 如果
  • 详解十大经典排序算法(四):希尔排序(Shell Sort)

    算法原理 希尔排序是一种基于插入排序的排序算法 也被称为缩小增量排序 它通过将待排序的序列分割成若干个子序列 对每个子序列进行插入排序 然后逐步缩小增量 最终使整个序列有序 算法描述 希尔排序 Shell Sort 是一种基于插入排序的算法
  • 使用tensorrt加速深度学习模型推断

    使用tensorrt加速深度学习模型推断 1 import以及数据加载 构建engine函数 2 导入官方模型及CIFAR100数据集 3 不采用tensort的推断时间 4 采用tensort加速 使用tensorrt 库 4 1 导出o