TensorRT学习(实战-自定义算子)

2023-11-18

YOLOv4进行TensorRT推理的时候会使用Mish激活函数,而使用到的mish激活函数没有在TensorRT进行实现。故需要进行实现对应TensorRT插件,故需要进行Mish激活函数的实现。

Mish激活函数定义

def mish_fun(x):
    tmp = np.log(1 + np.exp(x))
    tmp = np.tanh(tmp)
    tmp = tmp * x
    return tmp

如上所示激活函数的表达式,该激活还是比较复杂的,需要实现对应的tensorRT插件

tensorRT 插件

使用C++ API添加自定义图层
您可以通过从TensorRT的插件基类之一派生来实现自定义层。
从插件的一个基类派生插件类。它们在支持具有不同类型/格式的I/O或具有动态形状的网络方面具有不同的表达能力。下表总结了基类,按表达性从低到高的顺序排列。
注意:如果插件是用于一般用途,请提供FP32实现,以便允许它在任何网络上正常运行。

Table 3. Base Classes, Ordered from Least Expressive to Most Expressive
Introduced in TensorRT version? Mixed I/O formats/types Dynamic shapes? Supports implicit/explicit batch mode?
IPluginV2Ext 5.1 Limited No Both implicit and explicit batch modes
IPluginV2IOExt 6.0.1 General No Both implicit and explicit batch modes
IPluginV2DynamicExt 6.0.1 General Yes Explicit batch mode only

为了在网络中使用插件,您必须首先在TensorRT的PluginRegistry(C++,Python)中注册它。不是直接注册插件,而是注册插件的工厂类的实例,从PluginCreator(C++,Python)派生。plugin creator类还提供了有关插件的其他信息:其名称、版本和插件字段参数。

有两种方法可以在注册表中注册插件:

TensorRT提供了一个宏REGISTER_TENSORT_PLUGIN,用于在注册表中静态注册插件创建者。请注意,NREGISTER_TENSORT_PLUGI始终在默认名称空间(“”)下注册创建者。
通过创建您自己的入口点(类似于initLibNvInferPlugins)并在插件注册表上调用registerCreator来registerCreator。这比静态注册更好,因为它提供了潜在的更低的内存占用,并允许插件在唯一的名称空间下注册。这确保了在不同插件库之间的构建时间期间没有名称冲突。
用IPluginCreator::createPlugin()返回一个IPluginV2类型的插件对象。您可以使用addPluginV2()将插件添加到TensorRT网络,这将使用给定的插件创建网络层。

例如,您可以向网络添加插件层,如下所示:

// Look up the plugin in the registry
auto creator = getPluginRegistry()->getPluginCreator(pluginName, pluginVersion);
const PluginFieldCollection* pluginFC = creator->getFieldNames();
// Populate the fields parameters for the plugin layer 
// PluginFieldCollection *pluginData = parseAndFillFields(pluginFC, layerFields); 
// Create the plugin object using the layerName and the plugin meta data
IPluginV2 *pluginObj = creator->createPlugin(layerName, pluginData);
// Add the plugin to the TensorRT network 
auto layer = network.addPluginV2(&inputs[0], int(inputs.size()), pluginObj);
… (build rest of the network and serialize engine)
// Destroy the plugin object
pluginObj->destroy()
… (free allocated pluginData)


注意:前面描述的createPlugin方法在堆上创建了一个新的插件对象,并返回一个指向它的指针。确保销毁pluginObj,如前所示,以避免内存泄漏。
在序列化期间,TensorRT引擎在内部存储所有IPluginV2类型插件的插件类型、插件版本和命名空间(如果存在)。在反序列化过程中,TensorRT从插件注册表中查找插件创建者,并调用IPluginCreator::deserializePlugin()。当引擎被删除时,引擎通过调用IPluginV2::destroy()方法销毁在引擎构建期间创建的插件对象的克隆。您有责任确保您创建的插件对象在添加到网络后被释

注:

不要序列化所有插件参数:只有那些插件在运行时正确运行所需的。可以省略生成时间参数。
以相同的顺序序列化和反序列化插件参数。在反序列化过程中,验证插件参数是否初始化为默认值或反序列化值。未初始化的参数会导致未定义的行为。
如果您是汽车安全用户,则必须调用getSafePluginRegistry()而不是getPluginRegistry()。还必须使用REGISTER_SAFE_TENSORT_PLUGIN宏,而不是REGISTER_TENSORT_PLUGIN。

关键类说明

IPluginV2Ext

用户实现层的插件创建类。

virtual nvinfer1::DataType  getOutputDataType (int32_t index, nvinfer1::DataType const *inputTypes, int32_t nbInputs) const noexcept=0
返回请求索引处插件输出的DataType, More...
virtual bool  isOutputBroadcastAcrossBatch (int32_t outputIndex, bool const *inputIsBroadcasted, int32_t nbInputs) const noexcept=0
如果输出张量在批处理中广播,则返回true。. More...
virtual bool  canBroadcastInputAcrossBatch (int32_t inputIndex) const noexcept=0
如果插件可以使用跨批广播的输入而无需复制,则返回true. More...
virtual void  configurePlugin (Dims const *inputDims, int32_t nbInputs, Dims const *outputDims, int32_t nbOutputs, DataType const *inputTypes, DataType const *outputTypes, bool const *inputIsBroadcast, bool const *outputIsBroadcast, PluginFormat floatFormat, int32_t maxBatchSize) noexcept=0
使用输入和输出数据类型配置图层. More...
IPluginV2Ext ()=default
~IPluginV2Ext () override=default
virtual void  attachToContext (cudnnContext *, cublasContext *, IGpuAllocator *) noexcept
将插件对象附加到执行上下文,并授予插件对某些上下文资源的访问权限 More...
virtual void  detachFromContext () noexcept
从执行上下文中分离插件对象. More...
IPluginV2Ext *  clone () const noexcept override=0

克隆插件对象。这也会复制内部插件参数,并返回一个带有这些参数的新插件对象。如果源插件预先配置了configurePlugin(),返回的对象也应该是预先配置好的。返回的对象应该允许连接到上下文() 克隆的插件对象可以与源对象共享相同的每引擎不可变资源(例如,权重)(例如,经由引用计数)以避免重复

 Public Member Functions inherited from nvinfer1::IPluginV2
virtual AsciiChar const *  getPluginType () const noexcept=0
返回插件类型。应与相应插件创建者返回的插件名称匹配。 More...
virtual AsciiChar const *  getPluginVersion () const noexcept=0
返回插件版本。应该与相应插件创建者返回的插件版本相匹配r. More...
virtual int32_t  getNbOutputs () const noexcept=0
获取层的输出数量 More...
virtual Dims  getOutputDimensions (int32_t index, Dims const *inputs, int32_t nbInputDims) noexcept=0
获取输出张量的维度. More...
virtual bool  supportsFormat (DataType type, PluginFormat format) const noexcept=0
检查格式支持 More...
virtual int32_t  initialize () noexcept=0
I初始化要执行的层。这在引擎创建时调用More...
virtual void  terminate () noexcept=0
释放插件层初始化过程中获取的资源。这叫engine被毁 More...
virtual size_t  getWorkspaceSize (int32_t maxBatchSize) const noexcept=0
    查找图层所需的工作空间大小. More...
virtual int32_t  enqueue (int32_t batchSize, void const *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept=0
执行图层 More...
virtual size_t  getSerializationSize () const noexcept=0
    查找所需的序列化缓冲区的大小 More...
virtual void  serialize (void *buffer) const noexcept=0
    序列化图层r. More...
virtual void  destroy () noexcept=0
    销毁插件对象。这将在网络、构建器或引擎被销毁时调用. More...
virtual void  setPluginNamespace (AsciiChar const *pluginNamespace) noexcept=0
设置此插件对象所属的命名空间。理想情况下,同一插件库中的所有插件对象都应该具有相同的命名空间. More...
virtual AsciiChar const *  getPluginNamespace () const noexcept=0
返回插件对象的命名空间. More...

Protected Member Functions

int32_t  getTensorRTVersion () const noexcept override
返回构建此插件的API版本。TensorRT保留的高位字节,用于区分此插件和IPluginV2IPluginV2More...
void  configureWithFormat (Dims const *, int32_t, Dims const *, int32_t, DataTypePluginFormat, int32_t) noexcept override
派生类不应该实现这个。在C++11 API中,它将被override final. More...

IPluginCreator

用户实现层的插件创建器类。

virtual int32_t  getTensorRTVersion () const noexcept
返回插件创建者编译时使用的API版本. More...
virtual AsciiChar const *  getPluginName () const noexcept=0
返回插件名称。. More...
virtual AsciiChar const *  getPluginVersion () const noexcept=0
    返回插件版本. More...
virtual PluginFieldCollection const *  getFieldNames () noexcept=0
返回需要传递给createPlugin的字段列表. More...
virtual IPluginV2 *  createPlugin (AsciiChar const *name, PluginFieldCollection const *fc) noexcept=0
返回一个插件对象。错误时返回nullptr. More...
virtual IPluginV2 *  deserializePlugin (AsciiChar const *name, void const *serialData, size_t serialLength) noexcept=0
    在插件层的反序列化过程中调用。返回插件对象
virtual void  setPluginNamespace (AsciiChar const *pluginNamespace) noexcept=0
根据插件所属的插件库设置插件创建者的命名空间。这可以在注册插件创建者时设置。 More...
virtual AsciiChar const *  getPluginNamespace () const noexcept=0
返回插件创建者对象的命名空间。 More...
IPluginCreator ()=default
virtual  ~IPluginCreator ()=default

Mish激活函数

#ifndef _MISH_PLUGIN_H
#define _MISH_PLUGIN_H

#include <string>
#include <vector>
#include "NvInfer.h"

namespace nvinfer1
{
    class MishPlugin: public IPluginV2IOExt
    {
        public:
            // 显示的构造函数
            explicit MishPlugin();
            // 构造函数
            MishPlugin(const void* data, size_t length);
            ~MishPlugin();
            // 返回plugin的输出数量
            int getNbOutputs() const override
            {
                return 1;
            }
            // 输出张量的输出的维度
            Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override;
            // 初始化要执行的层。这在引擎创建时调用
            int initialize() override;
            // 释放插件层初始化过程中获取的资源
            virtual void terminate() override {};
            // 获得该层工作空间的大小
            virtual size_t getWorkspaceSize(int maxBatchSize) const override { return 0;}
            // 执行该的层的处理
            virtual int enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) override;
            //  查找所需的序列化缓冲区的大小
            virtual size_t getSerializationSize() const override;
            // 进行对应的序列化
            virtual void serialize(void* buffer) const override;

            bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const override {
                return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
            }

            const char* getPluginType() const override;

            const char* getPluginVersion() const override;
            //   销毁插件对象。这将在网络、构建器或引擎被销毁时调用
           void destroy() override;
            // 克隆对象
            IPluginV2IOExt* clone() const override;
            // 这是对象所属的命名空间
            void setPluginNamespace(const char* pluginNamespace) override;
            // 返回对象的命名空间
            const char* getPluginNamespace() const override;
            // 返回对象输出时间
            DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override;
            // 是否进行舒畅张量在批处理中广播
            bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const override;
            // 如果插件可以使用跨批广播的输入而无需刻意的复制
            bool canBroadcastInputAcrossBatch(int inputIndex) const override;
            // 将插件对象附加到执行上下文,
            void attachToContext(
                    cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) override;
            // 使用输入和输出数据类型配置网络层
            void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) override;
            // 从执行上下文中分离插件对象.
            void detachFromContext() override;

            int input_size_;
        private:
            void forwardGpu(const float *const * inputs, float* output, cudaStream_t stream, int batchSize = 1);
            int thread_count_ = 256;
            const char* mPluginNamespace;
    };

    class MishPluginCreator : public IPluginCreator
    {
        public:
            // 构造函数
            MishPluginCreator();
            // 析构函数
            ~MishPluginCreator() override = default;
            // 获得插件的名字
            const char* getPluginName() const override;
            // 获得插件的版本
            const char* getPluginVersion() const override;
            // 返回需要传递给createPlugin的字段列表
            const PluginFieldCollection* getFieldNames() override;
            // 返回一个对应的插件对象
            IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) override;
            // 在插件层的反序列化过程中调用。返回插件对象
            IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
            // 根据插件所属的插件库设置插件创建者的命名空间
            void setPluginNamespace(const char* libNamespace) override
            {
                mNamespace = libNamespace;
            }
            // 返回插件创建者对象的命名空间
            const char* getPluginNamespace() const override
            {
                return mNamespace.c_str();
            }

        private:
            std::string mNamespace;
            static PluginFieldCollection mFC;
            static std::vector<PluginField> mPluginAttributes;
    };
    REGISTER_TENSORRT_PLUGIN(MishPluginCreator);
};
#endif 

#include <cmath>
#include <stdio.h>
#include <cassert>
#include <iostream>
#include "mish.h"

namespace nvinfer1
{
    MishPlugin::MishPlugin()
    {
    }

    MishPlugin::~MishPlugin()
    {
    }

    // create the plugin at runtime from a byte stream
    MishPlugin::MishPlugin(const void* data, size_t length)
    {
        assert(length == sizeof(input_size_));
        input_size_ = *reinterpret_cast<const int*>(data);
    }

    void MishPlugin::serialize(void* buffer) const
    {
        *reinterpret_cast<int*>(buffer) = input_size_;
    }

    size_t MishPlugin::getSerializationSize() const
    {  
        return sizeof(input_size_);
    }

    int MishPlugin::initialize()
    { 
        return 0;
    }

    Dims MishPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims)
    {
        assert(nbInputDims == 1);
        assert(index == 0);
        input_size_ = inputs[0].d[0] * inputs[0].d[1] * inputs[0].d[2];
        // Output dimensions
        return Dims3(inputs[0].d[0], inputs[0].d[1], inputs[0].d[2]);
    }

    // Set plugin namespace
    void MishPlugin::setPluginNamespace(const char* pluginNamespace)
    {
        mPluginNamespace = pluginNamespace;
    }

    const char* MishPlugin::getPluginNamespace() const
    {
        return mPluginNamespace;
    }

    // Return the DataType of the plugin output at the requested index
    DataType MishPlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const
    {
        return DataType::kFLOAT;
    }

    // Return true if output tensor is broadcast across a batch.
    bool MishPlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const
    {
        return false;
    }

    // Return true if plugin can use input that is broadcast across batch without replication.
    bool MishPlugin::canBroadcastInputAcrossBatch(int inputIndex) const
    {
        return false;
    }

    void MishPlugin::configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput)
    {
    }

    // Attach the plugin object to an execution context and grant the plugin the access to some context resource.
    void MishPlugin::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator)
    {
    }

    // Detach the plugin object from its execution context.
    void MishPlugin::detachFromContext() {}

    const char* MishPlugin::getPluginType() const
    {
        return "Mish_TRT";
    }

    const char* MishPlugin::getPluginVersion() const
    {
        return "1";
    }

    void MishPlugin::destroy()
    {
        delete this;
    }

    // Clone the plugin
    IPluginV2IOExt* MishPlugin::clone() const
    {
        MishPlugin *p = new MishPlugin();
        p->input_size_ = input_size_;
        p->setPluginNamespace(mPluginNamespace);
        return p;
    }

    __device__ float tanh_activate_kernel(float x){return (2/(1 + expf(-2*x)) - 1);}

    __device__ float softplus_kernel(float x, float threshold = 20) {
        if (x > threshold) return x;                // too large
        else if (x < -threshold) return expf(x);    // too small
        return logf(expf(x) + 1);
    }

    __global__ void mish_kernel(const float *input, float *output, int num_elem) {

        int idx = threadIdx.x + blockDim.x * blockIdx.x;
        if (idx >= num_elem) return;

        //float t = exp(input[idx]);
        //if (input[idx] > 20.0) {
        //    t *= t;
        //    output[idx] = (t - 1.0) / (t + 1.0);
        //} else {
        //    float tt = t * t;
        //    output[idx] = (tt + 2.0 * t) / (tt + 2.0 * t + 2.0);
        //}
        //output[idx] *= input[idx];
        output[idx] = input[idx] * tanh_activate_kernel(softplus_kernel(input[idx]));
    }

    void MishPlugin::forwardGpu(const float *const * inputs, float* output, cudaStream_t stream, int batchSize) {
        int block_size = thread_count_;
        int grid_size = (input_size_ * batchSize + block_size - 1) / block_size;
        mish_kernel<<<grid_size, block_size>>>(inputs[0], output, input_size_ * batchSize);
    }

    int MishPlugin::enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream)
    {
        //assert(batchSize == 1);
        //GPU
        //CUDA_CHECK(cudaStreamSynchronize(stream));
        forwardGpu((const float *const *)inputs, (float*)outputs[0], stream, batchSize);
        return 0;
    }

    PluginFieldCollection MishPluginCreator::mFC{};
    std::vector<PluginField> MishPluginCreator::mPluginAttributes;

    MishPluginCreator::MishPluginCreator()
    {
        mPluginAttributes.clear();

        mFC.nbFields = mPluginAttributes.size();
        mFC.fields = mPluginAttributes.data();
    }

    const char* MishPluginCreator::getPluginName() const
    {
            return "Mish_TRT";
    }

    const char* MishPluginCreator::getPluginVersion() const
    {
            return "1";
    }

    const PluginFieldCollection* MishPluginCreator::getFieldNames()
    {
            return &mFC;
    }

    IPluginV2IOExt* MishPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
    {
        MishPlugin* obj = new MishPlugin();
        obj->setPluginNamespace(mNamespace.c_str());
        return obj;
    }

    IPluginV2IOExt* MishPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength)
    {
        // This object will be deleted when the network is destroyed, which will
        // call MishPlugin::destroy()
        MishPlugin* obj = new MishPlugin(serialData, serialLength);
        obj->setPluginNamespace(mNamespace.c_str());
        return obj;
    }

}

将该对象添加到模型中

Weights emptywts{DataType::kFLOAT, nullptr, 0};
    //卷积层处理
        //!
    //! \brief Add a multi-dimension convolution layer to the network.
    //!
    //! \param  The ininputput tensor to the convolution.
    //! \param nbOutputMaps The number of output feature maps for the convolution.
    //! \param kernelSize The multi-dimensions of the convolution kernel.
    //! \param kernelWeights The kernel weights for the convolution.
    //! \param biasWeights The optional bias weights for the convolution.
    // IConvolutionLayer* addConvolutionNd(
    //     ITensor& input, int32_t nbOutputMaps, Dims kernelSize, Weights kernelWeights, Weights biasWeights) noexcept
    // {
    //     return mImpl->addConvolutionNd(input, nbOutputMaps, kernelSize, kernelWeights, biasWeights);
    // }
    IConvolutionLayer* conv1 = network->addConvolutionNd(input, outch, DimsHW{ksize, ksize}, weightMap["module_list." + std::to_string(linx) + ".Conv2d.weight"], emptywts);
    assert(conv1);
    // 设置对应参数
    conv1->setStrideNd(DimsHW{s, s});
    conv1->setPaddingNd(DimsHW{p, p});

    // 设置对应的批量归一化数据
    IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), "module_list." + std::to_string(linx) + ".BatchNorm2d", 1e-4);

    // 创建对应的mish激活函数
    auto creator = getPluginRegistry()->getPluginCreator("Mish_TRT", "1");
    const PluginFieldCollection* pluginData = creator->getFieldNames();
    IPluginV2 *pluginObj = creator->createPlugin(("mish" + std::to_string(linx)).c_str(), pluginData);
    ITensor* inputTensors[] = {bn1->getOutput(0)};
    auto mish = network->addPluginV2(&inputTensors[0], 1, *pluginObj);
    return mish;
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TensorRT学习(实战-自定义算子) 的相关文章

随机推荐