编译 Keras 模型

2023-10-27

本篇文章译自英文文档 Compile Keras Models

作者是 Yuwei Hu

更多 TVM 中文文档可访问 →TVM 中文站

本文介绍如何用 Relay 部署 Keras 模型。

首先安装 Keras 和 TensorFlow,可通过 pip 快速安装:

pip install -U keras --user
pip install -U tensorflow --user

或参考官网:https://keras.io/#installation

import tvm
from tvm import te
import tvm.relay as relay
from tvm.contrib.download import download_testdata
import keras
import tensorflow as tf
import numpy as np

加载预训练的 Keras 模型

加载 Keras 提供的预训练 resnet-50 分类模型:

if tuple(keras.__version__.split(".")) < ("2", "4", "0"):
    weights_url = "".join(
        [
            "https://github.com/fchollet/deep-learning-models/releases/",
            "download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5",
        ]
    )
    weights_file = "resnet50_keras_old.h5"
else:
    weights_url = "".join(
        [
            " https://storage.googleapis.com/tensorflow/keras-applications/",
            "resnet/resnet50_weights_tf_dim_ordering_tf_kernels.h5",
        ]
    )
    weights_file = "resnet50_keras_new.h5"

weights_path = download_testdata(weights_url, weights_file, module="keras")
keras_resnet50 = tf.keras.applications.resnet50.ResNet50(
    include_top=True, weights=None, input_shape=(224, 224, 3), classes=1000
)
keras_resnet50.load_weights(weights_path)

加载测试图像

这里使用的还是先前猫咪的图像:

from PIL import Image
from matplotlib import pyplot as plt
from tensorflow.keras.applications.resnet50 import preprocess_input

img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
img = Image.open(img_path).resize((224, 224))
plt.imshow(img)
plt.show()
# 预处理输入
data = np.array(img)[np.newaxis, :].astype("float32")
data = preprocess_input(data).transpose([0, 3, 1, 2])
print("input_1", data.shape)

请添加图片描述

输出结果:

input_1 (1, 3, 224, 224)

使用 Relay 编译模型

将 Keras 模型(NHWC 布局)转换为 Relay 格式(NCHW 布局):

shape_dict = {"input_1": data.shape}
mod, params = relay.frontend.from_keras(keras_resnet50, shape_dict)
# 编译模型
target = "cuda"
dev = tvm.cuda(0)

# TODO(mbs):opt_level=3 导致 nn.contrib_conv2d_winograd_weight_transform
# 很可能由于潜在的错误,最终出现在 cuda 上的内存验证失败的模块中。
# 注意:只能在 evaluate() 中传递 context,它不被 create_executor() 捕获。
with tvm.transform.PassContext(opt_level=0):
    model = relay.build_module.create_executor("graph", mod, dev, target, param).evaluate()

在 TVM 上执行

dtype = "float32"
tvm_out = model(tvm.nd.array(data.astype(dtype)))
top1_tvm = np.argmax(tvm_out.numpy()[0])

查找分类集名称

在 1000 个类的分类集中,查找分数最高的第一个:

synset_url = "".join(
    [
        "https://gist.githubusercontent.com/zhreshold/",
        "4d0b62f3d01426887599d4f7ede23ee5/raw/",
        "596b27d23537e5a1b5751d2b0481ef172f58b539/",
        "imagenet1000_clsid_to_human.txt",
    ]
)
synset_name = "imagenet1000_clsid_to_human.txt"
synset_path = download_testdata(synset_url, synset_name, module="data")
with open(synset_path) as f:
    synset = eval(f.read())
print("Relay top-1 id: {}, class name: {}".format(top1_tvm, synset[top1_tvm]))
# 验证 Keras 输出的正确性
keras_out = keras_resnet50.predict(data.transpose([0, 2, 3, 1]))
top1_keras = np.argmax(keras_out)
print("Keras top-1 id: {}, class name: {}".format(top1_keras, synset[top1_keras]))

输出结果:

Relay top-1 id: 285, class name: Egyptian cat
Keras top-1 id: 285, class name: Egyptian cat

下载 Python 源代码:from_keras.py

下载 Jupyter Notebook:from_keras.ipynb

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

编译 Keras 模型 的相关文章

随机推荐

  • ubuntu20.04防火墙相关命令整理

    1 查看防火墙状态 sudo ufw status 2 开启防火墙 sudo ufw enable 3 关闭防火墙 sudo ufw disable 4 重启防火墙 sudo ufw reload 4 开启指定端口 sudo ufw all
  • pytorch踩坑日记

    昨天使用pytorch写一个程序 程序写完之后却一直不能正确运行 今天定位到了代码的问题所在 我的代码其中有一处逻辑是这样的 get a 这里的a就是我想反向求导更新的参数 b torch nonzero a 得到a里面所有不为0的下标 f
  • Spring自定义注解定义AOP配置去xml

    原理参考ImportBeanDefinitionRegistrar SPI简化Spring开发 spring中AOP使用非常广泛 引入方式一般分为两种 注解方式或xml方式 直接方式使用 AspectJ这样的注解 其缺点是需要手写切面实现业
  • 机器学习笔记--1.6数据可视化

    1 表与线性结构的可视化 Python提供四种容器结构 list dict set tuple来装载数据 其中线性结构有两种 list和tuple 由于tuple是只读结构 仅用于外部生成器生成的数据 所以最常用的线性结构就是list im
  • Ldap简单介绍(转)

    注 文章内容转载 觉得对ldap初次接触的你我非常的实用 关于LDAP的概念随便网上有很多 我不想重复 这里只是说一下我自己的 理解 都说它是 轻量级目录协议 太专业 我不懂 我只把它想象成 简单 的 目录协议 几个很重要的概念 以后会用到
  • Linux下MySQL安装

    MySQL安装 过程 下载官方包 wget i c http dev mysql com get mysql57 community release el7 10 noarch rpm 成功信息 FINISHED 2023 03 20 09
  • php预览md文件,用HTML+CSS做一个实时预览的markdown编辑器

    这次给大家带来用HTML CSS做一个实时预览的markdown编辑器 用HTML CSS做一个实时预览的markdown编辑器的注意事项有哪些 下面就是实战案例 一起来看一下 第一步 搭建布局 1 构思布局 以下是总体布局 2 项目下新建
  • 在cmd命令下启动软件

    1 配置jdk 1 找到jdk的安装路径 点开到bin目录下 复制这个目录 如下图 2 我的电脑 右键属性 高级系统设置 环境变量 双击 如下图 3 系统变量 path 双击 如下图 4 粘贴上面复制的路径到变量值最前面 末尾以英文的逗号结
  • TypeScript反射机制动态创建类

    前言 在前一篇文章桥接模式与策略模式的区别与刘伟老师的桥接模式中 我们可以明白桥接模式处理得比较好的一个点是在于Java的反射机制 那么 假如我们需要再TypeScript中 来实现桥接模式的处理 需要怎么样来实现这个 反射 呢 注 在策略
  • 【计算机毕业设计】045新闻推荐系统

    一 系统截图 需要演示视频可以私聊 摘要 随着信息互联网购物的飞速发展 国内放开了自媒体的政策 一般企业都开始开发属于自己内容分发平台的网站 本文介绍了新闻推荐系统的开发全过程 通过分析企业对于新闻推荐系统的需求 创建了一个计算机管理新闻推
  • Python3 requests_htm 设置代理

    简介 Python上有一个非常著名的HTTP库 requests 相比大家都听说过 用过的人都说好 现在requests库的作者又发布了一个新库 叫做requests html 看名字也能猜出来 这是一个解析HTML的库 而且用起来和req
  • QT入门Containers之QGroupBox、QDockWidget

    目录 一 QGroupBox界面相关 1 布局介绍 二 QDockWidget的介绍 1 去除标题栏 2 设置垂直属性 3 代码测试下 三 Demo展示 此文为作者原创 创作不易 转载请标明出处 一 QGroupBox界面相关 1 布局介绍
  • 在Tomcat中部署war包,404

    用IDEA中的mevan插件打包后 放在服务器中 访问404 放war包的位置没有问题 端口也开放了 就是访问不到 解决方法为 在启动类上继承 SpringBootServletInitializer 然后重写config方法 再次打包后
  • Android 9 (P)非SDK API限制调用开发指南

    Android 9 P 非SDK API限制调用开发指南 Android 9 P 开发适配指南系列博客目录 Adnroid 9 P recovery升级Map of cache recovery block map failed问题分析指南
  • anaconda中安装mysql

    官网下载安装包mysql官网下载地址 下载后进行解压 解压到本地文档之中 入目录之中新建 ini文件 新建文本文件 编辑完成之后 另存为选择全部文件 my ini文件之中内容如下 mysql 设置mysql客户端默认字符集 default
  • WPS如何使用VBA

    WPS专业版可以使用VBA的 非专业版没有测试过 不清楚 WPS专业版如何安装VBA呢 下载这个包 WPS目前已支持VBA 7 1版本 VBA For WPS 2019 zip 按顺序1 2 3 4即可 WPS 11 8 2 12011测试
  • Mac移动硬盘分区无法装载

    https blog csdn net tyforfreedom article details 48092901
  • win10安装linux子系统详细教程(非虚拟机方式)

    文章目录 1 前言 2 安装Windows Terminal 3 开启Windows子系统功能 4 安装Centos子系统 5 使用Centos子系统 1 前言 对于程序员来说 Linux技能基本是必备技能了 通常操作Linux有两种情况
  • vue/iview的table单元格可编辑,可上下键切换,小键盘enter可选中下一个

    在开发过程中 前后至今遇到好几次的编辑输入框编辑情况 4 24之前的版本 虽然改进好几次操作 但是都是用render函数实现 发现有时并不是很好操作 而且隐藏好几个bug 今天 2019 7 19 发布一个新版 目前无发现bug 而且监听键
  • 编译 Keras 模型

    本篇文章译自英文文档 Compile Keras Models 作者是 Yuwei Hu 更多 TVM 中文文档可访问 TVM 中文站 本文介绍如何用 Relay 部署 Keras 模型 首先安装 Keras 和 TensorFlow 可通