编译 MXNet 模型

2023-10-28

本篇文章译自英文文档 Compile MXNet Models

作者是 Joshua Z. ZhangKazutaka Morita

更多 TVM 中文文档可访问 →TVM 中文站

本文将介绍如何用 Relay 部署 MXNet 模型。

首先安装 mxnet 模块,可通过 pip 快速安装:

pip install mxnet --user

或参考官方安装指南:https://mxnet.apache.org/versions/master/install/index.html

# 一些标准的导包
import mxnet as mx
import tvm
import tvm.relay as relay
import numpy as np

从 Gluon Model Zoo 下载 Resnet18 模型

本节会下载预训练的 imagenet 模型,并对图像进行分类。

from tvm.contrib.download import download_testdata
from mxnet.gluon.model_zoo.vision import get_model
from PIL import Image
from matplotlib import pyplot as plt

block = get_model("resnet18_v1", pretrained=True)
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_name = "cat.png"
synset_url = "".join(
    [
        "https://gist.githubusercontent.com/zhreshold/",
        "4d0b62f3d01426887599d4f7ede23ee5/raw/",
        "596b27d23537e5a1b5751d2b0481ef172f58b539/",
        "imagenet1000_clsid_to_human.txt",
    ]
)
synset_name = "imagenet1000_clsid_to_human.txt"
img_path = download_testdata(img_url, "cat.png", module="data")
synset_path = download_testdata(synset_url, synset_name, module="data")
with open(synset_path) as f:
    synset = eval(f.read())
image = Image.open(img_path).resize((224, 224))
plt.imshow(image)
plt.show()

def transform_image(image):
    image = np.array(image) - np.array([123.0, 117.0, 104.0])
    image /= np.array([58.395, 57.12, 57.375])
    image = image.transpose((2, 0, 1))
    image = image[np.newaxis, :]
    return image

x = transform_image(image)
print("x", x.shape)

在这里插入图片描述
输出结果:

Downloading /workspace/.mxnet/models/resnet18_v1-a0666292.zip08d19deb-ddbf-4120-9643-fcfab19e7541 from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/resnet18_v1-a0666292.zip...
x (1, 3, 224, 224)

编译计算图

只需几行代码,即可将 Gluon 模型移植到可移植计算图上。mxnet.gluon 支持 MXNet 静态图(符号)和 HybridBlock。

shape_dict = {"data": x.shape}
mod, params = relay.frontend.from_mxnet(block, shape_dict)
## 添加 softmax 算子来提高概率
func = mod["main"]
func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs)

接下来编译计算图:

target = "cuda"
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(func, target, params=params)

输出结果:

/workspace/python/tvm/driver/build_module.py:268: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
  "target_host parameter is going to be deprecated. "

在 TVM 上执行可移植计算图

接下来用 TVM 重现相同的前向计算:

from tvm.contrib import graph_executor

dev = tvm.cuda(0)
dtype = "float32"
m = graph_executor.GraphModule(lib["default"](dev))
# 设置输入
m.set_input("data", tvm.nd.array(x.astype(dtype)))
# 执行
m.run()
# 得到输出
tvm_output = m.get_output(0)
top1 = np.argmax(tvm_output.numpy()[0])
print("TVM prediction top-1:", top1, synset[top1])

输出结果:

TVM prediction top-1: 282 tiger cat

使用带有预训练权重的 MXNet 符号

MXNet 常用 arg_params 和 aux_params 分别存储网络参数,下面将展示如何在现有 API 中使用这些权重:

def block2symbol(block):
    data = mx.sym.Variable("data")
    sym = block(data)
    args = {}
    auxs = {}
    for k, v in block.collect_params().items():
        args[k] = mx.nd.array(v.data().asnumpy())
    return sym, args, auxs

mx_sym, args, auxs = block2symbol(block)
# 通常将其保存/加载为检查点
mx.model.save_checkpoint("resnet18_v1", 0, mx_sym, args, auxs)
# 磁盘上有 "resnet18_v1-0000.params" 和 "resnet18_v1-symbol.json"

对于一般性 MXNet 模型:

mx_sym, args, auxs = mx.model.load_checkpoint("resnet18_v1", 0)
# 用相同的 API 来获取 Relay 计算图
mod, relay_params = relay.frontend.from_mxnet(mx_sym, shape_dict, arg_params=args, aux_params=auxs)
# 重复相同的步骤,用 TVM 运行这个模型

下载 Python 源代码:from_mxnet.py

下载 Jupyter Notebook:from_mxnet.ipynb

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

编译 MXNet 模型 的相关文章

  • 将 JSON 发布到 Python CGI

    我已经安装了 Apache2 并且 Python 可以工作 但我有一个问题 我有两页 一个是 Python 页面 另一个是带有 JQuery 的 Html 页面 有人可以告诉我如何让我的 ajax 帖子正常工作吗
  • 如何在 Debian 上的 virtualenv 中安装 numpy?

    注 参见这另一篇文章 https stackoverflow com questions 6442754 how to install h5py numpylibhdf5 as non root on a debian linux syst
  • 如何(重新)命名 pandas 数据框中的空列标题而不导出到 csv

    我有一个熊猫数据框df1带有一个索引列和一系列未命名的值 我想为未命名的系列指定一个名称 到目前为止 我知道的唯一方法是导出到df1 csv using df1 to csv df1 csv header Signal 然后使用以下命令重新
  • 通过 python 中的另外两个修改数组[重复]

    这个问题在这里已经有答案了 假设我们有三个一维数组 A 长度为 5 B 长度相同 示例中为5 C 更长 比如长度为 100 C最初用零填充 A给出索引C应更改的元素 它们可能会重复 以及B给出应添加到初始零的值C 例如 如果A 1 3 3
  • 来自 pandas 数据帧的烛台图,用日期替换索引

    此代码给出了带有移动平均线的烛台图 但 x 轴位于索引中 我需要 x 轴位于日期中 需要做什么改变 import numpy as np import pandas as pd import matplotlib pyplot as plt
  • 雅虎财务请求功能出现 404 客户端错误

    yahoo Financials的请求功能出现404 Client Error 直接点击以下网址没有问题 https finance yahoo com quote AAPL financials p AAPL https finance
  • 将 Python Pandas DataFrame 写入 Word 文档

    我正在努力创建一个使用 Pandas DataFrames 的 Python 生成的报告 目前我正在使用DataFrame to string 方法 但是 这会作为字符串写入文件 有没有办法让我实现这一目标 同时将其保留为表格 以便我可以使
  • 带有 mkdocs 的本地 mathjax

    我想在无法访问互联网的计算机上使用 MathJax 和 Mkdocs 因此我不能只调用 Mathjax CDN Config mkdocs yml site name My Docs extra javascript javascripts
  • 在linux上安装python ssl模块,无需重新编译

    是否可以在已经安装了 OpenSSL 的 Linux 机器上安装 python 的 SSL 模块 而无需重新编译 python 我希望它就像复制几个文件并将它们包含在库路径中一样简单 Python版本是2 4 3 谢谢 是否可以在已经安装了
  • 用 Python 绘制直方图

    我有两个列表 x 和 y x 包含字母表 A Z Y 包含它们在文件中的频率 我尝试研究如何在直方图中绘制这些值 但在理解如何绘制它方面没有成功 n bins patches plt hist x 26 normed 1 facecolor
  • 正在使用 PIL 保存损坏的图像

    我遇到一个问题 操作图像像素导致保存损坏的图像 因此 我使用 PIL 打开图像 然后将其转换为 NumPy 数组 image Image open myimage png np image np asarray image 然后 我转置图像
  • 数据框中 .map(str) 和 .astype(str) 有什么区别

    我有一个数据框 其列名为 col1 和 col2 的整数类型条目 我想将 col1 和 col2 的条目以及其间的 点 连接起来 我搜索并发现添加两个列条目 df col df col1 map str df col2 map str 并添
  • Python-验证我的文档 xls 中是否存在工作表

    我正在尝试在空闲时间设计一个小程序 加载 xls 文件 然后在要扫描的文档中选择一张纸 步骤1 用户导入 xls文件 导入程序后检查文件是否存在 我能做到的 第 2 步 我要求用户提供要分析的文档表 xls 的名称 这就是它停止的地方 该程
  • 无法导入QUERY_TERMS

    我正在运行一个网站Python and Django Django filters 2 1 installed Django 2 1 installed 当我运行时 我收到以下错误 importError Could not import
  • 如何展平解析树并存储在字符串中以进行进一步的字符串操作 python nltk

    我正在尝试从树结构中获取扁平树 如下所示 我想将整个树放在一个字符串中 就像没有检测到坏树错误一样 S NP SBJ NP DT The JJ high JJ seven day PP IN of NP DT the CD 400 NNS
  • Python 相当于 Scala 案例类

    Python 中是否有与 Scala 的 Case Class 等效的东西 就像自动生成分配给字段而无需编写样板的构造函数一样 当前执行此操作的现代方法 从 Python 3 7 开始 是使用数据类 https www python org
  • 使用 pandas 单元格中列表的长度选择行[重复]

    这个问题在这里已经有答案了 我有一张表 df a b c 1 x y x 2 x z c d 3 x t e f g 只是想知道如何使用 c 列的长度选择行 such as df loc len df c gt 1 我知道这是不对的 正确的
  • 如何使用 Python/Django 在 Facebook 中获取(和使用)扩展权限

    我正在尝试编写一个简单的应用程序 让用户授予我的代码写入其页面的 Facebook 流的权限 据我了解 它应该很简单 让用户单击一个按钮 启动一个弹出窗口 其中包含我的 Facebook 应用程序中的页面 在该页面中 他们单击授予的内容流发
  • 如何同时接受int和float类型的输入?

    我正在制作一个货币转换器 如何让 python 同时接受整数和浮点数 我就是这样做的 def aud brl amount From to ER 0 42108 if amount int if From strip aud and to
  • 基于值的 matplotlib 条形图颜色

    有没有一种方法可以根据条形图的值对条形图的条形进行着色 例如 values below 0 5 red values between 0 5 to 0 green values between 0 to 08 blue etc 我找到了一些

随机推荐