如何从 .h5 文件正确加载带有自定义层的 Keras 模型?

2023-11-24

我构建了一个带有自定义层的 Keras 模型,并将其保存到.h5通过回调文件ModelCheckPoint。 当我在训练后尝试加载该模型时,出现以下错误消息:

__init__() missing 1 required positional argument: 'pool_size'

这是自定义层及其的定义__init__ method:

class MyMeanPooling(Layer):
    def __init__(self, pool_size, axis=1, **kwargs):
        self.supports_masking = True
        self.pool_size = pool_size
        self.axis = axis
        self.y_shape = None
        self.y_mask = None
        super(MyMeanPooling, self).__init__(**kwargs)

这就是我将此层添加到模型中的方法:

x = MyMeanPooling(globalvars.pool_size)(x)

这就是我加载模型的方式:

from keras.models import load_model

model = load_model(model_path, custom_objects={'MyMeanPooling': MyMeanPooling})

这些是完整的错误消息:

Traceback (most recent call last):
  File "D:/My Projects/Attention_BLSTM/script3.py", line 9, in <module>
    model = load_model(model_path, custom_objects={'MyMeanPooling': MyMeanPooling})
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\saving.py", line 419, in load_model
    model = _deserialize_model(f, custom_objects, compile)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\saving.py", line 225, in _deserialize_model
    model = model_from_config(model_config, custom_objects=custom_objects)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\saving.py", line 458, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\layers\__init__.py", line 55, in deserialize
    printable_module_name='layer')
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\utils\generic_utils.py", line 145, in deserialize_keras_object
    list(custom_objects.items())))
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\network.py", line 1022, in from_config
    process_layer(layer_data)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\network.py", line 1008, in process_layer
    custom_objects=custom_objects)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\layers\__init__.py", line 55, in deserialize
    printable_module_name='layer')
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\utils\generic_utils.py", line 147, in deserialize_keras_object
    return cls.from_config(config['config'])
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\base_layer.py", line 1109, in from_config
    return cls(**config)
TypeError: __init__() missing 1 required positional argument: 'pool_size'

实际上我不认为你可以加载这个模型。

最有可能的问题是您没有实施get_config()你的层中的方法。此方法返回应保存的配置值的字典:

def get_config(self):
    config = {'pool_size': self.pool_size,
              'axis': self.axis}
    base_config = super(MyMeanPooling, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

将此方法添加到图层后,您必须重新训练模型,因为之前保存的模型没有保存此图层的配置。这就是为什么你无法加载它,进行此更改后需要重新训练。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何从 .h5 文件正确加载带有自定义层的 Keras 模型? 的相关文章

  • 存储整数列表的最有效方法

    我最近一直在做一个项目 其中一个目标是使用尽可能少的内存来使用 Python 3 存储一系列文件 除了一个整数列表之外 几乎所有文件都占用很少的空间 大致333 000整数长且整数可达约8000在尺寸方面 我目前正在使用pickle存储列表
  • 使用unicode在hdf5中存储字符串数据集

    我试图从包含特殊字符的文件中存储变量字符串表达式 例如 and 这是我的代码 import h5py as h5 file h5 File deleteme hdf5 a dt h5 special dtype vlen str dset
  • 布尔 pandas 之间的操作对称性破缺。具有不等索引的系列

    隐式索引匹配pandas用于不同之间的操作DataFrame Series很棒 而且大多数时候 它都有效 但是 我偶然发现了一个无法按预期工作的示例 import pandas as pd 0 21 0 import numpy as np
  • 如果出现重复,则主键取正值

    我有一个数据框df Key1 Key2 Value K11 K21 V1 K11 K21 V1 K13 K23 V2 K13 K23 V2 现在 例如对于相同的键 K11 K21 组合 我们有 2 个值 一负一正 如何从此 df 中仅获取正
  • 如何获取 sklearn.metrics.classification_report 的输出作为字典?

    我一直在尝试以字典的形式获得分类报告 所以根据 scikit learn 0 20 文档 我这样做 from sklearn import metrics rep metrics classification report y true y
  • pyVISA:以编程方式将仪器返回到本地模式

    我正在使用 pyVISA 来控制 GPIB 网络中的一些仪器 当我创建资源管理器时 GPIB 网络中的所有仪器都会进入远程模式 因此前面板显示被锁定并且不会更新 当我关闭资源管理器时 仪器仍处于远程模式 import visa rm vis
  • python是带有字符串的运算符行为[重复]

    这个问题在这里已经有答案了 我无法理解以下行为 我正在创建 2 个字符串 并使用 is 运算符来比较它 对于第一种情况 它的工作方式有所不同 对于第二种情况 它按预期工作 当我使用逗号或空格时 它显示是什么原因False与比较is当没有使用
  • PyMC3-自定义 theano Op 进行数值积分

    我使用 PyMC3 进行参数估计 使用必须定义的特定似然函数 我用谷歌搜索了一下 发现我应该使用densitydist实现用户定义的似然函数的方法 但它不起作用 如何在 PyMC3 中合并用户定义的似然函数并找出最大 aposteriori
  • mac安装Tensorflow出错

    我正在尝试使用以下说明在 mac 中安装 Tensorflow https www tensorflow org install https www tensorflow org install 但是当我想导入tensorflow时 我总是
  • 在Spyder(Python 3.6)中导入cv2时出现导入错误

    我已经在Windows操作系统中安装了opencv 3 0 0 我已运行该应用程序并已成功将其安装在C 驱动器并还复制了cv2 pyd文件输入C Python27 Lib site packages正如我在几个教程视频中看到的那样 在我的
  • 致命错误:Python.h:没有这样的文件或目录,python-Levenshtein 安装

    首先 我正在使用 Python 3 7 开发 Amazon EC2 实例 Amazon linux 版本 2 AMI 我正在尝试使用以下命令安装 python Levenshtein 包 pip3 install python Levens
  • pip 安装最新的依赖版本

    当我使用安装包时pip install e 它仅安装不满足的依赖项并忽略依赖项升级 如何在每次运行时安装最新的依赖版本pip install e 我尝试过使用pip install upgrade e 但是使用这个选项没有任何改变 我仍然得
  • setColumnStretch 和 setRowStretch 如何工作

    我有一个使用构建的应用程序PySide2它使用setColumnStretch用于柱拉伸和setRowStretch用于行拉伸 它工作得很好 但我无法理解它是如何工作的 我参考了 qt 文档 但它对我没有帮助 我被困在括号内的两个值上 例如
  • 使用 pyppeteer 与 asyncio 关联来抓取内容

    我用 python 结合编写了一个脚本pyppeteer随着asyncio从其登陆页面抓取不同帖子的链接 并最终通过跟踪通向其内页的 url 来获取每个帖子的标题 我这里解析的内容不是动态的 但是 我利用了pyppeteer and asy
  • 关于具有自定义损失的 3 输出 ANN 的加权

    我正在尝试定义一个自定义损失函数 它在回归模型中接收 3 个输出变量 def custom loss y true y pred y true c K cast y true float32 Shape batch size 3 y pre
  • model.predict() 返回类而不是概率

    Hello 我是第一次使用 Keras 我训练并保存了一个模型 作为 json 文件及其权重 该模型旨在将图像分为 3 个类别 我的编译方法 model compile loss categorical crossentropy optim
  • 如果多个测试有特定异常,则停止 pytest 测试

    我想使用停止测试套件pytest exit 如果任何测试因特定异常而失败 例如 50 个测试 其中任何一个都可能在某个时刻因该异常而失败 如果这些测试中至少有 2 个测试因该异常而失败 我想停止执行 我试图保留一个全局计数器 一个固定装置s
  • Huggingface 变形金刚模块未被 anaconda 识别

    我正在使用 Anaconda python 3 7 Windows 10 我尝试通过安装变压器https huggingface co transformers https huggingface co transformers 在我的环境
  • 如何将 UPX 与 pyinstaller 一起使用?

    如何将 UPX 与 pyinstaller 一起使用 我正在关注文档 我已经下载了UPX 我的文件如下所示 import csv import selenium import pandas print Hello 然后我运行 pyinsta
  • 如何将嵌套的Python字典转换为简单的命名空间?

    假设我有一个深度为 N 的嵌套字典 如何将每个内部嵌套字典转换为简单的命名空间 example input key0a test key0b key1a key2a keyNx key2b test key1b test example o

随机推荐

  • HTML 中的方括号形成数组。只是传统的还是有意义的?

    我经常看到 特别是在 PHP 世界中 如果你想创建一个 FORM 数组 可以这样写
  • 如何将 Firebase 身份验证令牌传递给 webView 并在 Android 上注册通知

    我有一个 Firebase WebApp 它向用户提供信息 除了应用程序之外 我还需要通过 Firebase 云消息传递向使用 Android 应用程序的用户发送推送通知 目标 用户应该一次登录到应用程序 既可以注册通知 又可以通过 Web
  • Google 时间轴图表持续时间(以小时为单位)

    我正在使用 Google 时间线图表 即使持续时间超过一天 我也想以小时为单位显示持续时间 是否可以 谢谢 包含一千个样本的图像 展示了不同的行为1正如您所看到的 红色的持续时间是错误的 蓝色的持续时间是计算和打印的 没有配置选项更改工具提
  • 在哪里可以找到当前的 C 或 C++ 标准文档?

    这个问题的答案是社区努力 编辑现有答案以改进这篇文章 目前不接受新的答案或互动 对于许多问题 答案似乎可以在 标准 中找到 然而 我们在哪里可以找到它呢 最好是在线 谷歌搜索有时会让人感到徒劳 尤其是对于 C 标准 因为它们淹没在编程论坛上
  • C 中的外部指针和静态指针

    您好 静态和外部指针的用法是什么 如果它们存在的话 为了回答您关于何时可以使用它们的问题 举几个简单的例子 静态指针可用于实现始终向程序返回相同缓冲区的函数 并在第一次调用时分配它 char GetBuffer static char bu
  • Java错误:应该在名为[重复]的文件中声明

    这个问题在这里已经有答案了 我对 Java 相当陌生 并试图弄清楚如何解决以下错误 读取错误 CalculatorWithMemory java 1 class Calculator is public should be declared
  • 如何找到最近的标记 leaflet.js

    我想知道是否真的有某种方法可以使用 leaflet js 找到我位置附近的标记 我首先想到的是存储我所在位置的纬度和经度 然后迭代一系列纬度和经度标记 将它们放入一个数组中 然后对该数组进行排序 我不确定这是否是一个好的选择 因为如果地图上
  • 路由器在 NAT 中保留记录多长时间?这些记录可以重复使用来转发来自其他主机的请求吗?

    有一个答案以简单的方式解释了路由器如何将请求从本地网络转换到外部网络并返回 https superuser com questions 105838 how does router know where to forward packet
  • 在 winapi 中拖放

    我有一个纯 Winapi 应用程序 需要一些新功能 其中之一最好实现为两个列表 您可以在列表之间拖放 多个 元素 新功能可以仅限于单个对话框 实现这一点的最快方法是什么 一些想法 纯Winapi 是DetectDrag 提供这一对话框的单独
  • Gradle resValue 导致重复字符串资源

    我的 Android 清单文件定义应用程序名称如下 android label string app name res values strings xml 中存在 app name 的相应条目 现在 在我的 build gradle 中
  • 如何使 Python/Sphinx 文档对象属性仅在 __init__ 中声明?

    我有带有对象属性的 Python 类 这些属性仅在运行构造函数时声明 如下所示 class Foo object def init self base self basepath base temp for run in os listdi
  • 从 C# 中的枚举获取字符串名称

    我已经声明了一个枚举 如下所示 public enum State KARNATAKA 1 GUJRAT 2 ASSAM 3 MAHARASHTRA 4 GOA 5 从外部来源 我得到的状态值为 1 或 2 或 3 或 4 或 5 根据我得
  • 是什么让 C 比 Python 更快? [关闭]

    就目前情况而言 这个问题不太适合我们的问答形式 我们希望答案得到事实 参考资料或专业知识的支持 但这个问题可能会引发辩论 争论 民意调查或扩展讨论 如果您觉得这个问题可以改进并可能重新开放 访问帮助中心以获得指导 我知道这可能是一个非常明显
  • C 宏 _Generic 给出意外的编译器错误

    使用 gcc exe Rev3 由 MSYS2 项目构建 8 2 0 我试图构建一个宏来自动在两种类型之间进行类型转换 其中两个参数永远不应该是相同的类型 我的问题是 如果我不包含相同类型的情况 编译器会抛出错误 我想要什么 include
  • 如何在signalR HubClass中使用UrlHelper

    我有一个从 Hub 驱动的聊天类 我想知道是否有一种方法可以通过 URLHelper 构建 URL 例如 Url Action action Controller 因为我可以从 2 个抽象类 集线器 控制器 派生该类 所以我不知道是否还有其
  • 如何设置适用于 Android 的 Google 云消息传递?

    我正在尝试实施Google Cloud Messaging for Android GCM 通过遵循demo 但我无法执行一些命令 例如 ant war android update project name GCMDemo p targe
  • iOS - 多次点击手势识别器

    在我的应用程序中 我必须检测单击 双击和三次点击 所以 我正在使用 UITapGestureRecognizer 我正在使用以下代码 UITapGestureRecognizer oneTap UITapGestureRecognizer
  • 调试 Sunspot 上的 Solr 搜索查询

    在 Rails 上使用 Sunspot gem 时如何调试 Solr 搜索查询 我有一些查询返回了异常高的分数 我试图弄清楚为什么会发生这种情况 似乎没有任何调试信息暴露给Sunspot 所以我认为我需要直接通过Solr进行调试 幸运的是
  • 文字闪烁 jQuery

    在 jQuery 中使文本闪烁的简单方法是什么以及停止它的方法是什么 必须适用于 IE FF 和 Chrome 谢谢 一个让某些文本闪烁的插件对我来说听起来有点矫枉过正 尝试这个 blink each function var elem t
  • 如何从 .h5 文件正确加载带有自定义层的 Keras 模型?

    我构建了一个带有自定义层的 Keras 模型 并将其保存到 h5通过回调文件ModelCheckPoint 当我在训练后尝试加载该模型时 出现以下错误消息 init missing 1 required positional argumen