keras 模型子类化示例

2024-01-24

从 Keras 2.2.0 开始,发布了模型定义的第 3 个 API:模型子类化。

根据常见问题解答:

然而,在子类模型中,模型的拓扑定义为 Python 代码(而不是静态的层图)。这意味着 无法检查或序列化模型的拓扑。结果, 以下方法和属性不可用于子类 楷模:

模型.输入和模型.输出。 model.to_yaml() 和 model.to_json() model.get_config() 和 model.save()。

保存经过训练的模型以进行推理的唯一选择是使用模型.save_weights方法。但是,我没有运气将模型加载回来进行推理。遇到的错误消息包括:

该模型从未被调用,因此尚未创建其权重,因此无法显示摘要。首先构建模型(例如,通过在一些测试数据上调用它)。 您正在尝试将包含 4 层的权重文件加载到具有 0 层的模型中。 未实现错误

谁能给出一个完整的示例来创建子类 keras 模型、训练和 save_weights,然后将其加载回来进行推理?


在尝试保存子类模型权重之前,您需要调用 tf.keras.Model.build 方法。另一种方法是在尝试保存模型权重之前对某些输入调用 tf.keras.Model.fit 或 tf.keras.Model.fit.call。这同样适用于将权重加载到新创建的子类模型实例中。在尝试加载权重之前,您需要调用上述方法之一。 以下示例显示了子类模型的保存和加载权重

import tensorflow as tf

print('TensorFlow', tf.__version__)

class ResidualBlock(tf.keras.Model):
    def __init__(self, block_type=None, n_filters=None):
        super(ResidualBlock, self).__init__()
        self.n_filters = n_filters
        if block_type == 'identity':
            self.strides = 1
        elif block_type == 'conv':
            self.strides = 2
            self.conv_shorcut = tf.keras.layers.Conv2D(filters=self.n_filters, 
                               kernel_size=1, 
                               padding='same',
                               strides=self.strides,
                               kernel_initializer='he_normal')
            self.bn_shortcut = tf.keras.layers.BatchNormalization(momentum=0.9)

        self.conv_1 = tf.keras.layers.Conv2D(filters=self.n_filters, 
                               kernel_size=3, 
                               padding='same',
                               strides=self.strides,
                               kernel_initializer='he_normal')
        self.bn_1 = tf.keras.layers.BatchNormalization(momentum=0.9)
        self.relu_1 = tf.keras.layers.ReLU()

        self.conv_2 = tf.keras.layers.Conv2D(filters=self.n_filters, 
                               kernel_size=3, 
                               padding='same', 
                               kernel_initializer='he_normal')
        self.bn_2 = tf.keras.layers.BatchNormalization(momentum=0.9)
        self.relu_2 = tf.keras.layers.ReLU()

    def call(self, x, training=False):
        shortcut = x
        if self.strides == 2:
            shortcut = self.conv_shorcut(x)
            shortcut = self.bn_shortcut(shortcut)
        y = self.conv_1(x)
        y = self.bn_1(y)
        y = self.relu_1(y)
        y = self.conv_2(y)
        y = self.bn_2(y)
        y = tf.add(shortcut, y)
        y = self.relu_2(y)
        return y

class ResNet34(tf.keras.Model):
    def __init__(self, include_top=True, n_classes=1000):
        super(ResNet34, self).__init__()

        self.n_classes = n_classes
        self.include_top = include_top
        self.conv_1 = tf.keras.layers.Conv2D(filters=64, 
                                               kernel_size=7, 
                                               padding='same', 
                                               strides=2, 
                                               kernel_initializer='he_normal')
        self.bn_1 = tf.keras.layers.BatchNormalization(momentum=0.9)
        self.relu_1 = tf.keras.layers.ReLU()
        self.maxpool = tf.keras.layers.MaxPool2D(3, 2, padding='same')
        self.residual_blocks = tf.keras.Sequential()
        for n_filters, reps, downscale in zip([64, 128, 256, 512], 
                                              [3, 4, 6, 3], 
                                              [False, True, True, True]):
            for i in range(reps):
                if i == 0 and downscale:
                    self.residual_blocks.add(ResidualBlock(block_type='conv', 
                                                              n_filters=n_filters))
                else:
                    self.residual_blocks.add(ResidualBlock(block_type='identity', 
                                                              n_filters=n_filters))
        self.GAP = tf.keras.layers.GlobalAveragePooling2D()
        self.fc = tf.keras.layers.Dense(units=self.n_classes)

    def call(self, x, training=False):
        y = self.conv_1(x)
        y = self.bn_1(y)
        y = self.relu_1(y)
        y = self.maxpool(y)
        y = self.residual_blocks(y)
        if self.include_top:
            y = self.GAP(y)
            y = self.fc(y)
        return y

## saving weights
model = ResNet34()
model.build((1, 224, 224, 3))
model.summary()
model.save_weights('model_weights.h5')

## loading saved weights
model_new = ResNet34()
model_new.build((1, 224, 224, 3))
model_new.load_weights('model_weights.h5')

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

keras 模型子类化示例 的相关文章

  • Native TF 与 Keras TF 性能比较

    我使用本机和后端张量流创建了完全相同的网络 但在使用多个不同参数进行了多个小时的测试后 仍然无法弄清楚为什么 keras 优于本机张量流并产生更好 稍微但更好 的结果 Keras 是否实现了不同的权重初始化方法 或者执行除 tf train
  • 批量归一化,是还是否?

    我使用 Tensorflow 1 14 0 和 Keras 2 2 4 以下代码实现了一个简单的神经网络 import numpy as np np random seed 1 import random random seed 2 imp
  • Keras ImageDataGenerator 相当于 csv 文件

    我在文件夹中排序了一堆数据 如下图所示 我需要构建一个 DataIterator 以便将数据放入神经网络模型中 当数据是图像时 我找到了很多例子来解决这个问题 使用 Keras 类图像数据生成器及其方法流自目录 但当数据是 csv 结构时则
  • NotImplementedError:尚未为未构建的模型子类启用“fit_generator”

    我正在使用以下代码 import tensorflow as tf traindata tf keras preprocessing image ImageDataGenerator rescale 1 255 shear range 0
  • 在不同的 GPU 上同时训练多个 keras/tensorflow 模型

    我想在 Jupyter Notebook 中同时在多个 GPU 上训练多个模型 我正在使用 4GPU 的节点上工作 我想将一个 GPU 分配给一个模型并同时训练 4 个不同的模型 现在 我通过 例如 为一台笔记本选择 GPU import
  • 将 Pytorch LSTM 的状态参数转换为 Keras LSTM

    我试图将现有的经过训练的 PyTorch 模型移植到 Keras 中 在移植过程中 我陷入了LSTM层 LSTM 网络的 Keras 实现似乎具有三种状态类型的状态矩阵 而 Pytorch 实现则具有四种状态矩阵 例如 对于hidden l
  • scikit-learn 和tensorflow 有什么区别?可以一起使用它们吗?

    对于这个问题我无法得到满意的答案 据我了解 TensorFlow是一个数值计算库 经常用于深度学习应用 而Scikit learn是一个通用机器学习框架 但它们之间的确切区别是什么 TensorFlow 的目的和功能是什么 我可以一起使用它
  • 在 GPU 支持下对高维数据进行更快的 Kmeans 聚类

    我们一直在使用 Kmeans 来对日志进行聚类 典型的数据集有 10 mill 具有 100k 特征的样本 为了找到最佳 k 我们并行运行多个 Kmeans 并选择轮廓得分最佳的一个 在 90 的情况下 我们最终得到的 k 介于 2 到 1
  • ValueError:请使用“Layer”实例初始化“TimeDistributed”层

    我正在尝试构建一个可以在音频和视频样本上进行训练的模型 但出现此错误ValueError Please initialize TimeDistributed layer with a Layer instance You passed Te
  • 如何在 keras 中添加可训练的 hadamard 产品层?

    我试图在训练样本中引入稀疏性 我的数据矩阵的大小为 比如说 NxP 我想将其传递到一个层 keras 层 该层的权重大小与输入大小相同 即可训练权重矩阵W的形状为NxP 我想对这一层的输入矩阵进行哈达玛乘积 逐元素乘法 W 按元素与输入相乘
  • Tensorflow中通过字符串选择不同的模式

    我正在尝试构建一个 VAE 网络 我希望模型在不同的模式下做不同的事情 我有三种模式 训练 相同 和 不同 以及一个名为 interpolation mode 的函数 它根据模式执行不同的操作 我的代码如下所示 import tensorf
  • 阻止 TensorFlow 访问 GPU? [复制]

    这个问题在这里已经有答案了 有没有一种方法可以纯粹在CPU上运行TensorFlow 我机器上的所有内存都被运行 TensorFlow 的单独进程占用 我尝试将 per process memory fraction 设置为 0 但未成功
  • Tensorflow 中的自定义资源

    由于某些原因 我需要为 Tensorflow 实现自定义资源 我试图从查找表实现中获得灵感 如果我理解得好的话 我需要实现3个TF操作 创建我的资源 资源的初始化 例如 在查找表的情况下填充哈希表 执行查找 查找 查询步骤 为了促进实施 我
  • GradientTape 根据损失函数是否被 tf.function 修饰给出不同的梯度

    我发现计算的梯度取决于 tf function 装饰器的相互作用 如下所示 首先 我为二元分类创建一些合成数据 tf random set seed 42 np random seed 42 x tf random normal 2 1 y
  • pip:需要将包名称tensorflow-gpu更改为tensorflow

    我正在尝试将具有 GPU 支持的张量流安装到 conda 环境中 我使用命令 pip install ignore installed upgrade https storage googleapis com tensorflow linu
  • 张量流服务错误:参数无效:JSON 对象:没有命名输入

    我正在尝试使用 Amazon Sagemaker 训练模型 并且希望使用 Tensorflow 服务来为其提供服务 为了实现这一目标 我将模型下载到 Tensorflow 服务 docker 并尝试从那里提供服务 Sagemaker 的训练
  • 如何将神经网络的输出限制在特定范围内?

    我正在使用 Keras 进行回归任务 并希望将输出限制在一个范围内 例如 1 到 10 之间 有没有办法保证这一点 像这样编写自定义激活函数 a simple custom activation from keras import back
  • 在 Keras 中连接两个目录迭代器

    假设我有类似以下内容 image data generator ImageDataGenerator rescale 1 255 train generator image data generator flow from director
  • 错误:分配具有形状的张量时出现 OOM

    在使用 Apache JMeter 进行性能测试期间 我面临着初始模型的问题 错误 分配形状为 800 1280 3 和类型的张量时出现 OOM 通过分配器浮动在 job localhost replica 0 task 0 device
  • 使用队列从多个输入文件中统一采样

    我的数据集中的每个类都有一个序列化文件 我想使用队列来加载每个文件 然后将它们放入 RandomShuffleQueue 中 这样我就可以从每个类中获得随机的示例组合 我认为这段代码会起作用 在此示例中 每个文件有 10 个示例 filen

随机推荐

  • Mongoengine PointField 给出了预期的位置对象,位置数组格式不正确错误

    我有一个模型如下 class Station Document location PointField 尝试按如下方式写入数据 station Station station location type Point coordinates
  • 与多个项目和开发人员签署程序集的最佳实践

    我正在寻找在拥有 30 多个开发人员 20 多个解决方案和 60 多个项目的组织中应用签名程序集的建议和最佳实践 我们使用 Visual Studio Team System 2008 和 TFS 虽然创建密钥和签署程序集是一个非常简单且直
  • SQL Server 2000 实时数据镜像

    我目前正在使用 2 个 sql 2000 服务器 其中一个可以查询 但不能添加任何数据库 这导致第二个服务器有很多查询 这些查询使用第一个服务器作为链接服务器 我想在查询实时数据的同时提高性能 是否可以将实时数据镜像到第二台服务器 这样查询
  • AWS ACM SSL 协议错误

    我正在使用 AWS EC2 实例 亚马逊 Linux 弹性 IP 尝试通过 ACM 设置 SSL 证书已验证 负载均衡器正在通过健康检查 侦听 prot 443 转发到端口 80 最初 在测试 https 时 我收到连接被拒绝的消息 这让我
  • 为什么需要 virtualenv?

    我是 Python 初学者 I read virtualenvPython项目开发时首选 我根本无法理解这一点 为什么是virtualenv首选 虚拟环境 http virtualenv readthedocs org en latest
  • 如何添加迄今为止的天数(作为列的值)?

    我在 Spark 中向日期格式列添加天数 数字 时遇到问题 我知道有一个功能date add它有两个参数 日期列和整数 date add date startdate tinyint smallint int days 我想使用整数类型的列
  • Siri 快捷方式 iOS 13 错误 INUIAddVoiceShortcutButton

    在我的项目中 我使用 Siri 快捷方式INUIAddVoiceShortcutButton 我使用这种方法来创建按钮并关联NSUserActivity let button INUIAddVoiceShortcutButton style
  • 如何在不使用 tabindex 的情况下进行 Tab 键切换时跳过项目?

    在 javascript onfocus 处理程序中 是否有一种好方法可以将焦点转移到 Tab 键顺序中的下一个项目 而无需手动输入下一个项目的 ID 我在 Django jQuery 中构建了一个 HTML 日期选择器 这是一个行编辑 然
  • 发送不带接受/拒绝选项的 Outlook 会议请求

    我正在使用我的 NET 程序发送 Outlook 会议请求 使用以下内容作为源 在没有 Outlook 的情况下发送 Outlook 会议请求 https stackoverflow com questions 461889 sending
  • 如何制作 GUI?

    我为 Nintendo DS 制作了 GUI 系统的许多不同的独立部分 例如按钮 文本框和选择框 但我需要一种将这些类包含在一个 Gui 类中的方法 以便我可以将所有内容都绘制到屏幕上一次 并立即检查所有按钮以检查是否有任何按钮被按下 我的
  • 是否可以将批量 FFT 与 CUDA 的 cuFFT 库和 cufftPlanMany 重叠?

    我正在尝试并行化称为 Chromaprint 的声学指纹识别库的 FFT 变换 它的工作原理是 将原始音频分割成许多重叠的帧并对它们应用傅立叶变换 Chromaprint 使用 4096 的帧大小 2 3 重叠 例如 第一帧由元素 0 40
  • 可以用 Electron 进行复制/粘贴吗?

    我正在使用 Electron Nightmare js 进行单元测试 我需要复制一个string到 clibboard gt 聚焦某个元素 gt 粘贴内容 然后测试是关于我的 JavaScript 是否正常处理 我在电子文档中读到剪贴板 A
  • R 中使用 mapply 对子集参数进行非标准评估

    我无法使用subset的论证xtabs or aggregate 或我测试过的任何功能 包括ftable and lm with mapply 以下调用失败并显示subset争论 但它们的工作没有 mapply FUN xtabs form
  • 将 pyQt UI 转换为 python

    有没有一种方法可以将使用 qtDesigner 形成的 ui 转换为 python 版本来使用 而无需额外的文件 我在这个 UI 中使用 Maya 并且将此 UI 文件转换为可读的 Python 版本来实现 这真的很棒 您可以使用pyuic
  • 获取每组最新的n条记录

    假设我有下表 id coulmn id value date 1 10 a 2016 04 01 1 11 b 2015 10 02 1 12 a 2016 07 03 1 13 a 2015 11 11 2 11 c 2016 01 10
  • java.lang.RuntimeException:无法实例化服务

    我正在尝试编写一个监视短信的应用程序 我想从我的主 Activity 类启动一个服务 但该服务由于某种原因没有启动 我认为我在清单文件中声明服务的方式或从活动中调用它的方式可能存在问题 这是我的活动代码的一部分 public class T
  • 如何从 Log4j Logger / Appender 中排除单个类?

    我有一个包 com example 这个包有五个类 我想将其中四个类记录到一个文件中 但排除第五个类 我可以写四个记录器 例如logger name com example Class1 并将相同的附加程序添加到所有四个记录器 有没有更简单
  • 如何对非托管 C++ Dll 进行强命名?

    我正在开发一个 C 应用程序 它使用EasyHook 库 http easyhook codeplex com 用于 DLL 注入 EasyHook 要求任何使用它的应用程序都必须是强命名的 为了对应用程序进行强命名 我需要确保我使用的所有
  • 立即调用函数表达式 (IIFE) 相对于普通函数的优势

    我对 javascript 很陌生 我读过模块模式 https addyosmani com resources essentialjsdesignpatterns book modulepatternjavascript提供某种名称空间并
  • keras 模型子类化示例

    从 Keras 2 2 0 开始 发布了模型定义的第 3 个 API 模型子类化 根据常见问题解答 然而 在子类模型中 模型的拓扑定义为 Python 代码 而不是静态的层图 这意味着 无法检查或序列化模型的拓扑 结果 以下方法和属性不可用