无法保存自定义子类模型

2024-06-26

灵感来自tf.keras.Model 子类化 https://www.tensorflow.org/guide/keras#model_subclassing我创建了自定义模型。
我可以训练它并获得成功的结果,但是我无法保存它.
我使用 python3.6 和tensorflow v1.10(或v1.9)

最小完整代码示例在这里:

import tensorflow as tf
from tensorflow.keras.datasets import mnist


class Classifier(tf.keras.Model):
    def __init__(self):
        super().__init__(name="custom_model")

        self.batch_norm1 = tf.layers.BatchNormalization()
        self.conv1 = tf.layers.Conv2D(32, (7, 7))
        self.pool1 = tf.layers.MaxPooling2D((2, 2), (2, 2))

        self.batch_norm2 = tf.layers.BatchNormalization()
        self.conv2 = tf.layers.Conv2D(64, (5, 5))
        self.pool2 = tf.layers.MaxPooling2D((2, 2), (2, 2))

    def call(self, inputs, training=None, mask=None):
        x = self.batch_norm1(inputs)
        x = self.conv1(x)
        x = tf.nn.relu(x)
        x = self.pool1(x)

        x = self.batch_norm2(x)
        x = self.conv2(x)
        x = tf.nn.relu(x)
        x = self.pool2(x)

        return x


if __name__ == '__main__':
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    x_train = x_train.reshape(*x_train.shape, 1)[:1000]
    y_train = y_train.reshape(*y_train.shape, 1)[:1000]

    x_test = x_test.reshape(*x_test.shape, 1)
    y_test = y_test.reshape(*y_test.shape, 1)

    y_train = tf.keras.utils.to_categorical(y_train)
    y_test = tf.keras.utils.to_categorical(y_test)

    model = Classifier()

    inputs = tf.keras.Input((28, 28, 1))

    x = model(inputs)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(10, activation="sigmoid")(x)

    model = tf.keras.Model(inputs=inputs, outputs=x)
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
    model.fit(x_train, y_train, epochs=1, shuffle=True)

    model.save("./my_model")

错误信息:

1000/1000 [==============================] - 1s 1ms/step - loss: 4.6037 - acc: 0.7025
Traceback (most recent call last):
  File "/home/user/Data/test/python/mnist/mnist_run.py", line 62, in <module>
    model.save("./my_model")
  File "/home/user/miniconda3/envs/ml3.6/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py", line 1278, in save
    save_model(self, filepath, overwrite, include_optimizer)
  File "/home/user/miniconda3/envs/ml3.6/lib/python3.6/site-packages/tensorflow/python/keras/engine/saving.py", line 101, in save_model
    'config': model.get_config()
  File "/home/user/miniconda3/envs/ml3.6/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py", line 1049, in get_config
    layer_config = layer.get_config()
  File "/home/user/miniconda3/envs/ml3.6/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py", line 1028, in get_config
    raise NotImplementedError
NotImplementedError

Process finished with exit code 1

我查看了错误行并发现获取配置方法检查self._is_graph_network

有人处理这个问题吗?

Thanks!

更新1:
在 keras 2.2.2 上(不是 tf.keras)
找到评论(用于模型保存)
文件:keras/engine/network.py
功能:获取配置

# 子类网络不可序列化
#(除非序列化是通过
# 子类网络的作者)。

所以,显然这是行不通的...
我想知道,他们为什么不在书中指出这一点文档 https://www.tensorflow.org/guide/keras(例如:“使用子类化而无法保存!”)

更新2:
在发现keras文档 https://keras.io/models/about-keras-models/:

在子类化模型中,模型的拓扑被定义为 Python 代码
(而不是作为层的静态图)。这意味着模型的
无法检查或序列化拓扑。结果,出现以下情况
方法和属性不可用于子类模型:

模型.输入和模型.输出。
model.to_yaml() 和 model.to_json()
model.get_config() 和 model.save()。

因此,无法通过使用子类化来保存模型。
可以只使用Model.save_weights()


TensorFlow 2.2

感谢 @cal 注意到我新的 TensorFlow 支持保存自定义模型!

通过使用 model.save 保存整个模型,并使用 load_model 恢复以前存储的子类模型。以下代码片段描述了如何实现它们。

class ThreeLayerMLP(keras.Model):

  def __init__(self, name=None):
    super(ThreeLayerMLP, self).__init__(name=name)
    self.dense_1 = layers.Dense(64, activation='relu', name='dense_1')
    self.dense_2 = layers.Dense(64, activation='relu', name='dense_2')
    self.pred_layer = layers.Dense(10, name='predictions')

  def call(self, inputs):
    x = self.dense_1(inputs)
    x = self.dense_2(x)
    return self.pred_layer(x)

def get_model():
  return ThreeLayerMLP(name='3_layer_mlp')

model = get_model()
# Save the model
model.save('path_to_my_model',save_format='tf')

# Recreate the exact same model purely from the file
new_model = keras.models.load_model('path_to_my_model')

See: 使用 Keras 保存和序列化模型 - 第二部分:保存和加载子类模型 https://www.tensorflow.org/guide/keras/save_and_serialize#part_ii_saving_and_loading_of_subclassed_models

TensorFlow 2.0

TL;DR:

  1. 不使用model.save()用于自定义子类 keras 模型;
  2. use save_weights() and load_weights()反而。

在 Tensorflow 团队的帮助下,事实证明保存自定义子类 Keras 模型的最佳实践是保存其权重并在需要时加载回来。

我们不能简单地保存 Keras 自定义子类模型的原因是它包含自定义代码,无法安全地序列化。但是,当我们具有相同的模型结构和自定义代码时,可以毫无问题地保存/加载权重。

Keras 的作者 Francois Chollet 编写了一篇很棒的教程,介绍了如何在 Colab 中的 Tensorflow 2.0 中保存/加载顺序/函数/Keras/自定义子类模型,网址为here https://colab.research.google.com/drive/172D4jishSgE3N7AO6U2OKAA_0wNnrMOq#scrollTo=mJqOn0snzCRy. In 保存子类模型部分,它说:

顺序模型和功能模型是表示 DAG 层的数据结构。因此,它们可以安全地序列化和反序列化。

子类模型的不同之处在于它不是数据结构,而是 一段代码。模型的架构是通过主体定义的 的调用方法。这意味着模型的架构 无法安全地序列化。要加载模型,您需要有 访问创建它的代码(模型子类的代码)。 或者,您可以将此代码序列化为字节码(例如 通过酸洗),但这是不安全的,而且通常不可移植。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

无法保存自定义子类模型 的相关文章

随机推荐

  • Ansible:findall 正则表达式中变量的正确语法是什么

    我使用的是 Ansible 版本 2 9 我想做一个 GET 它返回一个信息块 从该信息中正则表达式一个 ID 该 ID 对应于我目前正在迭代的任何主机 然后使用该 ID 执行操作 我有正则表达式工作 https regex101 com
  • 从 pdf 和 word 文件中提取文本

    如何在 C 中从 pdf 或 word 文件中提取文本 删除粗体 图像和其他富文本格式媒体 您可以使用专为索引服务设计 由索引服务使用的过滤器 它们旨在从各种文档中提取纯文本 这对于在文档内部进行搜索非常有用 您可以将其用于 Office
  • \ufeff 标识符中的无效字符

    我有以下代码 import urllib request try url https www google com search q test headers usag Mozilla 5 0 Macintosh Intel Mac OS
  • 通过ELB访问AWS EC2实例

    我试图在弹性负载均衡器下设置两个实例 但无法弄清楚应该如何通过负载均衡器访问这些实例 我已经使用安全组设置了实例 以允许从任何地方访问某些端口 我可以使用 公共 DNS publicdns 主机名和端口 PORT 直接访问实例 http p
  • Asp.Net MVC3 - 如何创建动态 DropDownList

    我发现了很多关于此的文章 但我仍然不知道到底如何做到这一点 我正在尝试创建自己的博客引擎 我有用于创建文章的视图 我首先使用 EF 和代码 现在我必须填写应添加文章的类别数量 但我想将其更改为下拉列表 名称为类别 我的模型看起来是这样的 p
  • Hibernate 时间戳 - 毫秒精度

    似乎以毫秒精度存储时间戳是休眠的一个已知问题 我在数据库中的字段最初设置为时间戳 3 但我也尝试过日期时间 3 不幸的是 它没有任何区别 我尝试过使用 Timestamp 和 Date 类 最近我开始使用 joda time 库 经过所有这
  • Swift 3:将 UIButton 扩展添加到 ViewController

    我是 iOS Swift 的初学者 尝试创建一个没有 Storyboard 的简单应用程序 我创建了一个UIButton扩展名 我想在我的视图中添加一个简单的按钮 稍后将设置约束 不幸的是 该按钮不可见 如果有人帮助我 我将不胜感激 谢谢你
  • 查看tomcat服务器的连接数

    我在 Tomcat Server 5 5 17 上部署了一个 Java Java EE Web 应用程序 我想知道连接到服务器的客户端数量 我们怎样才能找到它呢 最可靠的方法是搜索ip addr of srv port in netstat
  • 通过 HTTP 代理进行 iOS XMPP 聊天

    我有一个 iPhone 应用程序 可与 2 项服务配合使用 通过 http 使用 REST 服务 使用 AFNetworking 通过 TCP 进行 XMPP 聊天 使用 XMPPFrameworkhttps github com robb
  • 具有 dropdown-menu-right 类的下拉菜单未向右对齐

    我有以下导航栏结构 current user username 来自我的模板系统 ul class navbar nav mr auto mt 2 mt lg 0 ul div class dropdown show a class dro
  • JS - 文件读取器 API 获取图像文件大小和尺寸

    您好 我正在使用以下代码来使用文件读取器 API 获取上传图像
  • 如何在 Quill(富文本编辑器)中检测和修剪前导/尾随空格?

    如何检测并删除前导 尾随空格Quill https quilljs com 哪个是富文本编辑器 例如 样本HTML下面代表Quill文本的输出 nHi 我们想要检测并删除由以下命令创建的每个文本块的前导和尾随空格Quill 不适用于整个文档
  • 如何制作像 Twitter 一样带有字符限制突出显示的文本区域?

    Twitter 的提交推文文本框会突出显示超出字符限制的字符 如您所见 超出字符限制的字符以红色突出显示 我怎样才能实现这样的目标 您将在这里找到必要的解决方案和所需的代码 超过 140 限制 即变为负数 时如何插入 标签 https st
  • 访问 2010 DLookUp

    第一次使用 MS Access 遇到了一些问题 如果有人可以指出我正确的方向 所以我正在做一个模拟数据库 所以它看起来很傻 只是为了了解细节 目前需要一些有关 DLookUp 的帮助 我的数据库有两个表 具有以下字段 C ID课程PK 学生
  • 预期在模拟中调用一次,但使用 Moq 时调用次数为 0 次

    我收到错误 在mock上调用一次 但是0次 下面是我的代码结构 public class GenerateAddress IGenerateAddress public GenerateAddress IAddress createAdd
  • 使用 MongoDB PHP 驱动程序时的安全问题

    我有在 MYSQL 上保护 sql 注入的经验 但是在使用 php 驱动程序的 MongoDB 上我应该注意什么 在大多数页面中 我通过 GET POST 和搜索 插入系统获取数据 我通过 UDID 其他字段进行搜索 并且可以插入任何字符串
  • 在android上获取电池温度

    android 如何获取电池的温度 http developer android com reference android os BatteryManager html http developer android com referen
  • 如何使用 CMake 链接多个库

    我有一些与 DCMTK 相关的代码 如果我从命令行使用 g 我可以成功构建并运行它 这是代码 include dcmtk config osconfig h include dcmtk dcmdata dctk h int main Dcm
  • 基于 .NET 4 构建的 MEF 应用程序是否可以导入针对 .NET 3.5 构建的类型?

    我正在使用托管扩展性框架开发一个主机应用程序 它是针对 NET 4 和框架中内置的 System ComponentModel Composition 程序集构建的 我希望支持使用 NET 3 5 开发部件并以声明方式导出它们的能力 由于导
  • 无法保存自定义子类模型

    灵感来自tf keras Model 子类化 https www tensorflow org guide keras model subclassing我创建了自定义模型 我可以训练它并获得成功的结果 但是我无法保存它 我使用 pytho