【ONNX】pytorch模型导出成ONNX格式:支持多参数与动态输入

2023-11-13

        pytorch格式的模型在部署之前一般需要做格式转换。本文介绍了如何将pytorch格式的模型导出到ONNX格式的模型。ONNX(Open Neural Network Exchange)格式是一种常用的开源神经网络格式,被较多推理引擎支持,比如:ONNXRuntime, Intel OpenVINO, TensorRT等。

1. 网络结构定义        

我们以一个Image Super Resolution的模型为例。首先,需要知道模型的网络定义SuperResolutionNet,并创建模型对象torch_model:

# Super Resolution model definition in PyTorch
import torch.nn as nn
import torch.nn.init as init

class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor, inplace=False):
        super(SuperResolutionNet, self).__init__()

        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)
        init.zeros_(self.conv4.bias)

# Create the super-resolution model by using the above model definition.
torch_model = SuperResolutionNet(upscale_factor=3)

2. 加载模型文件

        Pytorch模型的参数信息存储在state_dict中。 state_dict是一个Python字典结构的对象,里面存储了神经网络中每层对应的参数张量。将每层的参数结构以及最后一层的bias打印出来:

def print_state_dict(state_dict):    
    print(len(state_dict))
    for layer in state_dict:
        print(layer, '\t', state_dict[layer].shape)
    print(state_dict['conv4.bias'])
print_state_dict(model.state_dict())

输出:

8
conv1.weight      torch.Size([64, 1, 5, 5])
conv1.bias      torch.Size([64])
conv2.weight      torch.Size([64, 64, 3, 3])
conv2.bias      torch.Size([64])
conv3.weight      torch.Size([32, 64, 3, 3])
conv3.bias      torch.Size([32])
conv4.weight      torch.Size([9, 32, 3, 3])
conv4.bias      torch.Size([9])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.])

        因为之前将第四层的bias初始化为0,所以输出是全零。然后调用load_state_dict加载模型文件,可以看到加载之后参数的变化。eval将模型设置为推理状态。

# Load pretrained model weights
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'

model.load_state_dict(model_zoo.load_url(model_url))

print_state_dict(model.state_dict())
# set the model to inference mode
model.eval()

输出

8
conv1.weight      torch.Size([64, 1, 5, 5])
conv1.bias      torch.Size([64])
conv2.weight      torch.Size([64, 64, 3, 3])
conv2.bias      torch.Size([64])
conv3.weight      torch.Size([32, 64, 3, 3])
conv3.bias      torch.Size([32])
conv4.weight      torch.Size([9, 32, 3, 3])
conv4.bias      torch.Size([9])
tensor([-0.0151, -0.0191, -0.0362, -0.0224,  0.0548,  0.0113,  0.0529,  0.0258,
        -0.0180])
SuperResolutionNet(
  (relu): ReLU()
  (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(32, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pixel_shuffle): PixelShuffle(upscale_factor=3)
)


3. 输出成ONNX格式

        在调用torch.onnx.export之前,需要先创建输入数据。因为模型的导出实际上是执行了一次推理过程。在执行的过程中记录使用到的操作。输入数据可以是随机的:

# Input to the model
x = torch.randn(1, 1, 224, 224, requires_grad=True)
# Export the model
torch.onnx.export(model,               # model being run
                  x,                         # model input 
                  "D:\\super_resolution.onnx",   # where to save the model (can be a file or file-like object)                  
                  opset_version=11,          # the ONNX version to export the model to                  
                  input_names = ['input'],   # the model's input names
                  output_names = ['output']  # the model's output names
                  )

        export的第一个参数是模型对象,第二个参数是输入数据,第三个参数是输出的模型文件名,这三个参数是必须指定的。还有一些常用的可选参数:

        opset_version, 指定的操作版本,一般越高的版本会支持更多的操作。如果遇到某个操作不支持,可以将版本号设置的高一点试试。
        input_names, 输入参数名。如果不指定,会使用默认名字。
        output_names, 输出参数名。如果不知道,会使用默认名字。
        输出成功后,可以使用Netron查看网络结构。Netron是一个开源的神经网络模型可视化工具,可以使用在线网页版的https://netron.app/,或者下载安装桌面版的https://github.com/lutzroeder/netron。打开导出的模型,结构如下:

在这里插入图片描述

4. 导出动态输入模型

        可以看到上面导出的模型输入是固定的1 x 1 x 224 x 224输出是固定的1 x 1 x 672 x 672.实际应用的时候输入图片的尺寸是不固定的,而且可能一次输入多种图片一起处理。我们可以通过指定dynamic_axes参数来导出动态输入的模型。dynamic_axes的参数是一个字典类型,字典的key就是输入或者输出的名字,对应key的value可以是一个字典或者列表,指定了输入或者输出的index以及对应的名字。比如想要让输入的index为0的维度表示动态的batch_size那么就指定{0: 'batch_size'}。同样的方法可以指定宽高所在的维度输出成动态的。

input_name = 'input'
output_name = 'output'
torch.onnx.export(model,               # model being run
                  x,                         # model input 
                  "D:\\super_resolution_2.onnx",   # where to save the model (can be a file or file-like object)                  
                  opset_version=11,          # the ONNX version to export the model to                  
                  input_names = [input_name],   # the model's input names
                  output_names = [output_name],  # the model's output names
                  dynamic_axes= {
                        input_name: {0: 'batch_size', 2 : 'in_width', 3: 'int_height'},
                        output_name: {0: 'batch_size', 2: 'out_width', 3:'out_height'}}
                  )

输出的模型使用Netron打开,结构如下:

在这里插入图片描述

        查看输入输出信息可以看到,输入的维度变成:[batch_size,1,in_width,int_height],输出的维度变成:[batch_size,1,out_width,out_height]。表示这个模型可以接收动态的批次大小和宽高尺寸。

在这里插入图片描述

5. 多参数输入

5.1 多参数输入模型的导出

        有时候可能会遇到比较复杂的模型,推理时需要输入多个参数的情况。我们可以通过将参数列表包在一个list中来输出ONNX模型。我们先将模型的forward方法修改一下,增加一个输入参数scale:

class SuperResolutionNet2(nn.Module):
    def __init__(self, upscale_factor, inplace=False):
        super(SuperResolutionNet2, self).__init__()

        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x, scale):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)
        init.zeros_(self.conv4.bias)  

# Create the super-resolution model by using the above model definition.
model2 = SuperResolutionNet2(upscale_factor=3)

调用export输出到ONNX:

input_name = 'input'
output_name = 'output'
torch.onnx.export(model2,               
                  (x, 2),                         
                  "D:\\super_resolution_3.onnx",   
                  opset_version=11,          
                  input_names = [input_name],  
                  output_names = [output_name],
                  dynamic_axes= {
                        input_name: {0: 'batch_size', 2 : 'in_width', 3: 'int_height'},
                        output_name: {0: 'batch_size', 2: 'out_width', 3:'out_height'}}
                  )

5.2 易错点

        由于export函数的机制,会把模型输入的参数自动转换成tensor类型,比如上面的scale参数,虽然传入的时候是int32类型,但是export在执行时会调用到forward函数,此时scale已经变成一个tensor类型。我们可以做个测试,打印一下scale的类型来验证:

def forward(self, x, scale):
    print(scale)
    x = self.relu(self.conv1(x))
    x = self.relu(self.conv2(x))
    x = self.relu(self.conv3(x))
    x = self.pixel_shuffle(self.conv4(x))
    return x

重新运行export后输出:

tensor(2)
        这种机制带来的影响是,在使用scale参数时可能需要做一个转换,比如转换成float类型。否则某些函数的调用会失败。以插值函数为例做个测试,将forward修改一下:

def forward(self, x, scale):
    print(scale)        
    y = F.interpolate(x, scale_factor= 1./scale, mode="bilinear")
    x = self.relu(self.conv1(x))
    x = self.relu(self.conv2(x))
    x = self.relu(self.conv3(x))
    x = self.pixel_shuffle(self.conv4(x))
    return x

这个时候运行export会报错,因为插值函数的scale_factor参数不能是一个tensor类型。修改后的正确版本:

def forward(self, x, scale):
    print(scale)        
    y = F.interpolate(x, scale_factor= 1./float(scale), mode="bilinear")
    x = self.relu(self.conv1(x))
    x = self.relu(self.conv2(x))
    x = self.relu(self.conv3(x))
    x = self.pixel_shuffle(self.conv4(x))
    return x

6. 完整代码

在这里 https://github.com/jb2020-super/pytorch-utils/blob/main/to_onnx_ex.ipynb

7. 参考

https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
https://pytorch.org/docs/stable/onnx.html?highlight=export#torch.onnx.export
https://onnxruntime.ai/docs/get-started/with-python.html
————————————————
Thanks to:https://blog.csdn.net/superbinlovemiaomi/article/details/121344667

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【ONNX】pytorch模型导出成ONNX格式:支持多参数与动态输入 的相关文章

  • 翠儿。让流永远运行

    我对 tweepy python 库比较陌生 我想确保我的流 python 脚本始终在远程服务器上运行 因此 如果有人能够分享如何实现这一目标的最佳实践 那就太好了 现在我正在这样做 if name main while True try
  • 将tensorflow 2.0 BatchDataset转换为numpy数组

    我有这个代码 train images test images tf keras datasets mnist load data train dataset tf data Dataset from tensor slices train
  • 行未从树视图复制

    该行未在树视图中复制 我在按行并复制并粘贴到未粘贴的任何地方后制作了弹出复制 The code popup tk Menu tree opportunity tearoff 0 def row copy item tree opportun
  • Tensorflow 可变图像输入大小(自动编码器、放大......)

    Edit WARNING不建议使用不同图像大小的图像 因为张量需要具有相同的大小才能实现并行化 我一直在寻找解决方案 了解如何使用不同大小的图像作为神经网络的输入 Numpy 第一个想法是使用numpy 然而 由于每个图像的大小不同 我无法
  • Python 3 __getattribute__ 与点访问行为

    我读了一些关于 python 的对象属性查找的内容 这里 https blog ionelmc ro 2015 02 09 understanding python metaclasses object attribute lookup h
  • Python BeautifulSoup XML 解析

    我编写了一个简单的脚本来使用 BeautifulSoup 模块解析 XML 聊天日志 标准 soup prettify 工作正常 只是聊天日志中有很多绒毛 您可以在下面看到我正在使用的脚本代码和一些 XML 输入文件 Code import
  • 为什么 Python 中的“pip install”会引发语法错误?

    我正在尝试使用 pip 安装软件包 我试着跑pip install从Python shell 但我得到了SyntaxError 为什么我会收到此错误 如何使用 pip 安装软件包 gt gt gt pip install selenium
  • 如何限制Django CreateView中ForeignKey字段的选择?

    我有一个沿着这些思路的模型结构 models py class Foo models Model class Bar models Model foo models ForeignKey Foo class Baz models Model
  • Python igraph:从图中删除顶点

    我正在使用安然电子邮件数据集 并尝试删除没有 enron com 的电子邮件地址 即我只想拥有安然电子邮件 当我尝试删除那些没有 enron com 的地址时 一些电子邮件由于某些原因被跳过 下面显示了一个小图 其中顶点是电子邮件地址 这是
  • 更改 pandas 中多个日期时间列的时区信息

    有没有一种简单的方法可以将数据帧中的所有时间戳列转换为本地 任何时区 不是逐列进行吗 您可以有选择地将转换应用于所有日期时间列 首先 选择它们select dtypes https pandas pydata org pandas docs
  • 获取列表中倒数第二个元素[重复]

    这个问题在这里已经有答案了 我可以通过以下方式获取列表的倒数第二个元素 gt gt gt lst a b c d e f gt gt gt print lst len lst 2 e 有没有比使用更好的方法print lst len lst
  • 将 Pandas 列中的列表拆分为单独的列

    这是我在 pandas 数据框中的 特征 列 Feature Cricket 82379 Kabaddi 255 Reality 4751 Cricket 15640 Wildlife 730 LiveTV 13 Football 4129
  • Django 在选择列表更改时创建毫无意义的迁移

    我正在尝试使用可调用创建一个带有选择字段的模型 以便 Django 在选择列表更改时不会创建迁移 如中所述this https stackoverflow com questions 31788450 stop django from cr
  • conda-env list / conda info --envs 如何查找环境?

    我一直在尝试 anaconda miniconda 因为我的用户使用随 miniconda 安装的结构生物学程序 并且作者都没有 A 考虑到可能存在其他 miniconda 应用程序 B 他们的程序将在多用户环境中使用 因此 使用 Arch
  • Airflow Python 单元测试?

    我想为我们的 DAG 添加一些单元测试 但找不到任何单元测试 有 DAG 单元测试框架吗 有一个端到端的测试框架存在 但我猜它已经死了 https issues apache org jira browse AIRFLOW 79 https
  • Flask WTForms 使用变量自动填充 StringField

    我有一个表格 我想用上一页收到的信息自动填充一些字段 但如果他们想调整它 它需要是可更改的 我正在为我的 SelectField 使用动态创建的列表 但添加 StringField 并不成功 请参阅下面的我的代码 forms py clas
  • 如何将列表字典写入字符串而不是 CSV 文件?

    This 堆栈溢出问题 https stackoverflow com questions 37997085 how to write a dictionary of lists to a csv file将列表字典写入 CSV 文件的答案
  • 在 Python 模块中使用 InstaLoader

    我正在尝试使用 Instaloader 下载与主题标签相关的照片以进行图像分析 我在GitHub存储库中找到了一个全面的方法 如何在终端中执行它 但是 我需要将脚本集成到Python笔记本中 这是脚本 instaloader no vide
  • Django South - 将 null=True 字段转换为 null=False 字段

    我的问题是 转变的最佳做法是什么null True场变成null False使用 Django South 的字段 具体来说 我正在与ForeignKey 你应该先写一个数据迁移 http south aeracode org docs t
  • 无法在 Windows 10 上构建 Detectron2

    尽管 Windows 上的 Detectron2 没有官方支持 但有很多可用的说明 我尝试按照这些说明进行操作 但最终出现了相同的错误 这是我的设置 OS Windows 10 专业版 19043 1466 微软视觉工作室 2019 CUD

随机推荐