tensorRT模型推理时动态shape

2023-11-05

动态shape

所谓动态shape就是编译时指定可动态的范围【L-H】,推理时可以允许L<=shape<=H。在全卷积网络中我们通常就是有这个诉求的,推理时的shape是可以动态改变的,不一定要限制死,这个动态shape不一定只宽高,还指batchsize也是动态的。

实现动态shape的操作主要就是修改下面提到的两个方面就行了。

1.构建网络时
1.1.必须在模型定义时,输入维度给定为-1,否则该维度不会动态。注意一下两点:

  • 若onnx文件,则onnx文件打开后如果维度是字母或-1的话那么它的维度就被认为是动态的。
  • 如果你的模型中存在reshape类操作,那么reshape的参数必须随动态进行计算。而大部分时候这都是问题。除非你是全卷积模型,否则大部分时候只需要为batch_size维度设置为动态,其他维度尽量避免设置动态

1.2.配置profile:

  • create: builder->createOptimizationProfile()
  • set: setDimensions()设置kMIN, kOPT, kMAX的一系列输入尺寸范围
  • add:config->addOptimizationProfile(profile);添加profile到网络配置中

2.推理阶段时

2.1.需要在选择profile的索引后设置input的维度即shape:

//这里就是指的把输入的shape设置成(1,1,3,3),Bindings的第0个索引表示输入(具体bindings的概念参见文章:tensorRT实现模型的推理过程)
execution_context->setBindingDimensions(0, nvinfer1::Dims4(1, 1, 3, 3))

代码示例

和之前全连接的代码唯一的区别就是两个点,一个是网络结构的定义换成了CNN,另一个是动态shape的配置createOptimizationProfile。OptimizationProfile是一个优化配置文件,用来指定输入的shape可以变换的范围的,不要被优化两个字蒙蔽了双眼,其实就是为了告诉tensorRT我的shape是什么范围!


// tensorRT include
#include <NvInfer.h>
#include <NvInferRuntime.h>

// cuda include
#include <cuda_runtime.h>

// system include
#include <stdio.h>
#include <math.h>

#include <iostream> 
#include <fstream> // 后面要用到ios这个库
#include <vector>

using namespace std;

class TRTLogger : public nvinfer1::ILogger{
public:
    virtual void log(Severity severity, nvinfer1::AsciiChar const* msg) noexcept override{
        if(severity <= Severity::kINFO){
            printf("%d: %s\n", severity, msg);
        }
    }
} logger;

nvinfer1::Weights make_weights(float* ptr, int n){
    nvinfer1::Weights w;
    w.count = n;
    w.type = nvinfer1::DataType::kFLOAT;
    w.values = ptr;
    return w;
}

bool build_model(){
    TRTLogger logger;

    // ----------------------------- 1. 定义 builder, config 和network -----------------------------
    nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);
    nvinfer1::IBuilderConfig* config = builder->createBuilderConfig();
    nvinfer1::INetworkDefinition* network = builder->createNetworkV2(1);

    // 构建一个模型
    /*
        Network definition:

        image
          |
        conv(3x3, pad=1)  input = 1(指的是channel), output = 1, bias = True     w=[[1.0, 2.0, 0.5], [0.1, 0.2, 0.5], [0.2, 0.2, 0.1]], b=0.0
          |
        relu
          |
        prob
    */


    // ----------------------------- 2. 输入,模型结构和输出的基本信息 -----------------------------
    const int num_input = 1;
    const int num_output = 1;
    float layer1_weight_values[] = {
        1.0, 2.0, 3.1, 
        0.1, 0.1, 0.1, 
        0.2, 0.2, 0.2
    }; // 行优先
    float layer1_bias_values[]   = {0.0};

    // 如果要使用动态shape,必须让NetworkDefinition的维度定义为-1,in_channel是固定的
    nvinfer1::ITensor* input = network->addInput("image", nvinfer1::DataType::kFLOAT, nvinfer1::Dims4(-1, num_input, -1, -1));
    nvinfer1::Weights layer1_weight = make_weights(layer1_weight_values, 9);
    nvinfer1::Weights layer1_bias   = make_weights(layer1_bias_values, 1);
    //网络定义卷积层
    auto layer1 = network->addConvolution(*input, num_output, nvinfer1::DimsHW(3, 3), layer1_weight, layer1_bias);
    layer1->setPadding(nvinfer1::DimsHW(1, 1));

    auto prob = network->addActivation(*layer1->getOutput(0), nvinfer1::ActivationType::kRELU); // *(layer1->getOutput(0))
     
    // 将我们需要的prob标记为输出
    network->markOutput(*prob->getOutput(0));

    int maxBatchSize = 10;
    printf("Workspace Size = %.2f MB\n", (1 << 28) / 1024.0f / 1024.0f);
    // 配置暂存存储器,用于layer实现的临时存储,也用于保存中间激活值
    config->setMaxWorkspaceSize(1 << 28);

    // --------------------------------- 2.1 关于profile ----------------------------------
    // profile就是关于实现模型编译时动态shape的配置!如果模型有多个输入,则必须多个profile
    auto profile = builder->createOptimizationProfile();

    // 配置最小允许1 x 1 x 3 x 3,kMIN表示配置的最小值
    profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMIN, nvinfer1::Dims4(1, num_input, 3, 3));
    //kOPT是表示配置的最优值
    profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kOPT, nvinfer1::Dims4(1, num_input, 3, 3));

    // 配置最大允许10 x 1 x 5 x 5,kMAX表示配置的最大值
    // if networkDims.d[i] != -1, then minDims.d[i] == optDims.d[i] == maxDims.d[i] == networkDims.d[i]
    profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMAX, nvinfer1::Dims4(maxBatchSize, num_input, 5, 5));
    config->addOptimizationProfile(profile);

    nvinfer1::ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
    if(engine == nullptr){
        printf("Build engine failed.\n");
        return false;
    }

    // -------------------------- 3. 序列化 ----------------------------------
    // 将模型序列化,并储存为文件
    nvinfer1::IHostMemory* model_data = engine->serialize();
    FILE* f = fopen("engine.trtmodel", "wb");
    fwrite(model_data->data(), 1, model_data->size(), f);
    fclose(f);

    // 卸载顺序按照构建顺序倒序
    model_data->destroy();
    engine->destroy();
    network->destroy();
    config->destroy();
    builder->destroy();
    printf("Done.\n");
    return true;
}

vector<unsigned char> load_file(const string& file){
    ifstream in(file, ios::in | ios::binary);
    if (!in.is_open())
        return {};

    in.seekg(0, ios::end);
    size_t length = in.tellg();

    std::vector<uint8_t> data;
    if (length > 0){
        in.seekg(0, ios::beg);
        data.resize(length);

        in.read((char*)&data[0], length);
    }
    in.close();
    return data;
}

void inference(){
    // ------------------------------- 1. 加载model并反序列化 -------------------------------
    TRTLogger logger;
    auto engine_data = load_file("engine.trtmodel");
    nvinfer1::IRuntime* runtime   = nvinfer1::createInferRuntime(logger);
    nvinfer1::ICudaEngine* engine = runtime->deserializeCudaEngine(engine_data.data(), engine_data.size());
    if(engine == nullptr){
        printf("Deserialize cuda engine failed.\n");
        runtime->destroy();
        return;
    }

    nvinfer1::IExecutionContext* execution_context = engine->createExecutionContext();
    cudaStream_t stream = nullptr;
    cudaStreamCreate(&stream);

    /*
        Network definition:

        image
          |
        conv(3x3, pad=1)  input = 1, output = 1, bias = True     w=[[1.0, 2.0, 0.5], [0.1, 0.2, 0.5], [0.2, 0.2, 0.1]], b=0.0
          |
        relu
          |
        prob
    */

    // ------------------------------- 2. 输入与输出 -------------------------------
    float input_data_host[] = {
        // batch 0
        1,   1,   1,
        1,   1,   1,
        1,   1,   1,

        // batch 1
        -1,   1,   1,
        1,   0,   1,
        1,   1,   -1
    };
    float* input_data_device = nullptr;

    // 3x3输入,对应3x3输出
    int ib = 2;
    int iw = 3;
    int ih = 3;
    float output_data_host[ib * iw * ih];
    float* output_data_device = nullptr;
    cudaMalloc(&input_data_device, sizeof(input_data_host));
    cudaMalloc(&output_data_device, sizeof(output_data_host));
    cudaMemcpyAsync(input_data_device, input_data_host, sizeof(input_data_host), cudaMemcpyHostToDevice, stream);


    // ------------------------------- 3. 推理 -------------------------------
    // 明确当前推理时,使用的数据输入大小,setBindingDimensions第一个参数为0指的是输入
    execution_context->setBindingDimensions(0, nvinfer1::Dims4(ib, 1, ih, iw));
    float* bindings[] = {input_data_device, output_data_device};
    bool success      = execution_context->enqueueV2((void**)bindings, stream, nullptr);
    cudaMemcpyAsync(output_data_host, output_data_device, sizeof(output_data_host), cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);


    // ------------------------------- 4. 输出结果 -------------------------------
    for(int b = 0; b < ib; ++b){
        printf("batch %d. output_data_host = \n", b);
        for(int i = 0; i < iw * ih; ++i){
            printf("%f, ", output_data_host[b * iw * ih + i]);
            if((i + 1) % iw == 0)
                printf("\n");
        }
    }

    printf("Clean memory\n");
    cudaStreamDestroy(stream);
    cudaFree(input_data_device);
    cudaFree(output_data_device);
    execution_context->destroy();
    engine->destroy();
    runtime->destroy();
}

int main(){

    if(!build_model()){
        return -1;
    }
    inference();
    return 0;
}
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

tensorRT模型推理时动态shape 的相关文章

  • 时序预测

    时序预测 MATLAB实现具有外生回归变量的ARIMAX时间序列预测 含AR MA ARIMA SARIMA VAR对比 目录 时序预测 MATLAB实现具有外生回归变量的ARIMAX时间序列预测 含AR MA ARIMA SARIMA V
  • 机器学习(二)深度学习实战-使用Kera预测人物年龄

    问题描述 我们的任务是从一个人的面部特征来预测他的年龄 用 Young Middle Old 表示 我们训练的数据集大约有19906多张照片及其每张图片对应的年龄 全是阿三的头像 测试集有6636张图片 首先我们加载数据集 然后我们通过深度

随机推荐

  • 本地部署体验LISA模型(LISA≈图像分割基础模型SAM+多模态大语言模型LLaVA)

    GitHub地址 https github com dvlab research LISA 该项目论文paper reading https blog csdn net Transfattyacids article details 132
  • jquery.webcam进行摄像头拍照

    最近由于项目要进行人像采集 所以就涉及到在web页面调用摄像头 进行拍照来获取图片 可以初来乍到 这技术又不是杠杠滴 所以在面对这有实现想法 但是又没有实现手段的时候 还是按照往常惯例找度娘 这个搜索过程可谓是无比艰辛 由于关键字不准确迟迟
  • WDK李宏毅学习笔记第十八周01_Meta learning-MAML and Gradient descent as LSTM

    Meta learning MAML and Gradient descent as LSTM 文章目录 Meta learning MAML and Gradient descent as LSTM 摘要 1 Meta learning
  • LO Frequency Plan

    概述 LO DIV是位于VCO和mixer之间的模块 其作用是分频和驱动长走线 设计难点在于底噪 不同的band有不同的频率覆盖范围 为了减小VCO的设计难度需要选择合适的分频方案 E UTRA规定的band与频率的对应关系在3GPP或wi
  • GNU/Linux下有多少是GNU的?

    原文地址 http coolshell cn articles 4826 html more 4826 一个葡萄牙的学生写了一篇文章 How much GNU is there in GNU Linux GNU Linux下有多少是GNU的
  • java模拟HTTP请求工具

    import org slf4j Logger import org slf4j LoggerFactory import java io BufferedReader import java io DataOutputStream imp
  • sqli-labs/Less-10

    这一关提示我们使用布尔和时间盲注相结合的做法 我们先去判断一下注入类型 输入1 and 1 2 存在回显 为字符型 输入1 存在回显 而且回显还一模一样 输入1 存在回显 而且回显当然是一摸一样的啦 我怀疑一直都是如此输出 所以根本不能使用
  • j2ee规范认识

    完成了J2EE视频的学习 三个系列的视频感觉走的是那么的艰难 在懵懵懂懂中进行着 在视频进行的时候已经对J2EE以及EJB的大体框架进行笔记记录和框架整理 接下来对在学习过程中的一些关键点进行总结 J2EE是什么 要想知道J2EE是什么就要
  • Open vSwitch流表查找分析

    流表查找过程是Open vSwitch核心中的核心 在此之前 庾志辉写过关于对Open vSwitch 下文简称OVS 源代码分析的系列博客 链接如下 http blog csdn net yuzhihui no1 article deta
  • 【漏洞复现】Microsoft Office MSDT 远程代码执行漏洞 (CVE-2022-30190)

    0x01 Microsoft Office Microsoft Office是由Microsoft 微软 公司开发的一套办公软件套装 常用组件有 Word Excel PowerPoint等 0x02 漏洞简介 该文档使用 Word 远程模
  • 如何连接到虚拟服务器上,虚拟主机如何连接服务器的

    虚拟主机如何连接服务器的 内容精选 换一换 GaussDB DWS 提供的gsql命令行客户端 它的运行环境是Linux操作系统 在使用gsql客户端远程连接GaussDB DWS 集群之前 需要准备一个Linux主机用于安装和运行gsql
  • 你真的知道运维是干嘛的吗?

    文章目录 前言 运维基本能力 运维岗位分类 按照职责划分 按照服务类型划分 按照运维模式划分 按照工作模式划分 按照管理层级划分 按照技术方向划分 按照服务对象划分 按照工作内容划分 按照服务形式划分 按照业务类型划分 按照技术栈划分 按照
  • python selenium页面跳转_Python爬虫之Selenium多窗口切换的实现

    前言 在页面操作过程中有时候点击某个链接会弹出新的窗口 但由于Selenium的所有操作都是在第一个打开的页面进行的 这时就需要主机切换到新打开的窗口上进行操作 WebDriver提供了switch to window 方法 可以实现在不同
  • 如何去实现机械灵巧手玩魔方和弹钢琴_工业级灵巧手与智慧抓取技术

    随着工业机器人的发展以及机器人应用领域的不断扩展 作为末端执行器的机器人夹爪的应用边界也在不断扩展 在工业自动化领域被广泛使用的气动手爪 正在被可精确控制 可数字化管理的新一代末端执行器所替代 编辑 符号整理 Cloud Sunny 机器人
  • “找不到或无法加载主类”该问题出现的一个可能原因

    今天按照教材上的程序 编译运行时 程序编译没有问题 但是运行时 出现 找不到或无法加载主类 的提示 遂网上四处找答案 说什么 1 拼写错误 2 环境变量配置时classpath和path前面未加 下面是我的程序 package myFram
  • vue——echarts柱状图横轴文字太多放不下【处理办法】

    1 如果单纯是文字太多 且中间无法分割开的话 可以采用两种方式 文字倾斜展示 效果 在options配置中的xAxis中配置如下代码 axisLabel interval 0 rotate 40 文字竖直显示 效果 在options配置中的
  • 终于搞定了SHADOWMAP,

    5 5pcf
  • 关于 ChatGPT 必看论文推荐【附论文链接】

    关于 ChatGPT 必看论文推荐 2022年11月 OpenAI推出人工智能聊天原型 ChatGPT 再次赚足眼球 为AI界引发了类似AIGC让艺术家失业的大讨论 ChatGPT 是一种专注于对话生成的语言模型 它能够根据用户的文本输入
  • 情人节用Python画玫瑰花

    用Python turtle 绘制的玫瑰花 效果图 import turtle import time turtle penup turtle setup 1100 1000 turtle hideturtle turtle speed 1
  • tensorRT模型推理时动态shape

    动态shape 所谓动态shape就是编译时指定可动态的范围 L H 推理时可以允许L lt shape lt H 在全卷积网络中我们通常就是有这个诉求的 推理时的shape是可以动态改变的 不一定要限制死 这个动态shape不一定只宽高