重置 Keras 层中的权重

2024-02-20

我想重置(随机化)Keras(深度学习)模型中所有层的权重。原因是我希望能够使用不同的数据分割多次训练模型,而不必每次都进行(缓慢的)模型重新编译。

灵感来自这次讨论 https://github.com/fchollet/keras/pull/1908,我正在尝试以下代码:

# Reset weights
for layer in KModel.layers:
    if hasattr(layer,'init'):
        input_dim = layer.input_shape[1]
        new_weights = layer.init((input_dim, layer.output_dim),name='{}_W'.format(layer.name))
        layer.trainable_weights[0].set_value(new_weights.get_value())

然而,它只起到了部分作用。

部分原因是我检查了一些 layer.get_weights() 值,它们似乎发生了变化。但是当我重新开始训练时,成本值远低于第一次运行时的初始成本值。这几乎就像我已经成功地重置了一些权重,但不是全部。


在编译模型之后、训练之前立即保存初始权重:

model.save_weights('model.h5')

然后在训练后,通过重新加载初始权重来“重置”模型:

model.load_weights('model.h5')

这为您提供了一个同类模型来比较不同的数据集,并且应该比重新编译整个模型更快。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

重置 Keras 层中的权重 的相关文章

  • 带有指针数组的 cython

    我在 python 中有一个 numpy ndarrays 列表 具有不同的长度 并且需要非常快速地访问 python 中的列表 我认为指针数组就可以解决问题 我试过 float type t list of arrays no of ar
  • Python 小数.InvalidOperation 错误

    当我运行这样的东西时 我总是收到此错误 from decimal import getcontext prec 30 b 2 3 Decimal b Error Traceback most recent call last File Te
  • Django 如何从 ManyToManyField 序列化并列出全部

    我正在使用 Django 1 9 1 开发移动应用程序后端 我实现了关注者模型 现在我想列出用户的所有关注者 但目前我不得不这样做 我还使用 Django Rest 框架 这是我的 UserProfile 模型 class UserProf
  • Python 使用 M2Crypto 通过 S/MIME 对消息进行签名

    我现在花了几个小时 但找不到我的错误 我想要一个简单的例程来创建 S MIME 签名消息 稍后可以与 smtplib 一起使用 这是我到目前为止所拥有的 usr bin python2 7 coding utf 8 from future
  • 在 keras 中使用自定义张量流操作

    我在张量流中有一个脚本 其中包含自定义张量流操作 我想将代码移植到 keras 但我不确定如何在 keras 代码中调用自定义操作 我想在 keras 中使用tensorflow 所以到目前为止我发现的教程描述了与我想要的相反的内容 htt
  • Python 字典 - 在 2 个字符的字符串中查找第二个字符,该字符产生最小值

    我想提交密钥的第一部分并返回该密钥的剩余部分 以最小化值 并从第一部分开始 例如 d ab 100 ac 200 ad 500 如果我要进去 a I would like to return b min d s s for s in d i
  • 理解@property装饰器和继承[重复]

    这个问题在这里已经有答案了 这里是 Python 3 以防万一它很重要 我试图正确理解如何实现继承 property使用 我已经搜索了 StackOverflow 并阅读了大约 20 个类似的问题 但无济于事 因为他们试图解决的问题略有不同
  • 使用sklearn进行多标签特征选择

    我希望使用 sklearn 对多标签数据集执行特征选择 我想要获得最终的功能集across标签 然后我将在另一个机器学习包中使用它 我打算使用我看到的方法here https stackoverflow com questions 1640
  • 当 DetailView 遇到时更新模型字段。 [姜戈]

    我有一个类似的 DetailViewviews py views py class CustomView DetailView context object name content model models AppModel templa
  • 如何使用 jira-python 设置 fixVersions 字段

    我正在尝试使用 jira python 模块 http jira python readthedocs org en latest 更新现有的 JIRA 具体来说 我正在尝试设置问题的fixesVersion 列表 我已经尝试了一段时间但没
  • 如何在matplotlib中基于x轴更改直方图颜色

    我有根据 pandas 数据框计算出的直方图 我想根据 x 轴值更改颜色 例如 If the value is 0 the color should be green If the value is gt 0 the color shoul
  • 如何在 Python 中仅列出 zip 存档中的文件夹?

    如何仅列出 zip 存档中的文件夹 这将列出存档中的每个文件夹和文件 import zipfile file zipfile ZipFile samples sample zip r for name in file namelist pr
  • 如何在 Python 中执行相当于预处理器指令的操作?

    有没有办法在 Python 中执行以下预处理器指令 if DEBUG lt do some code gt else lt do some other code gt endif There s debug 这是编译器预处理的特殊值 if
  • PyQt - 如何检查 QDialog 是否可见?

    我有个问题 我有这个代码 balls Ball for i in range 1 10 因此 当我说 Ball 时 这将在 QDialog 上绘制一个球 然后当这完成后 我正在移动球QDialog无限循环中 我想说类似的话while QDi
  • 如何使用数据库在 Django 中的应用程序之间交换数据?

    我正在使用 Django 在网络上工作 我创建了 2 个应用程序 第一个用于客户端注册并将其数据添加到数据库 第二个应用程序供用户访问和查看交互界面 这个想法是使用第二个应用程序从数据库中的客户端获取数据 并使用它向用户显示一些信息 我的问
  • 使用 Flask-SQLAlchemy 进行多对多多数据库连接

    我正在尝试使这个多对多联接与 Flask SQLAlchemy 和两个 MySQL 数据库一起工作 并且它非常接近 只是它为联接表使用了错误的数据库 这是基础知识 我有main db and vendor db 表格设置为main db u
  • 网页抓取 - 如何识别网页上的主要内容

    给定一个新闻文章网页 来自任何主要新闻来源 例如时报或彭博社 我想识别该页面上的主要文章内容 并丢弃其他杂项元素 例如广告 菜单 侧边栏 用户评论 在大多数主要新闻网站上都可以使用的通用方法是什么 有哪些好的数据挖掘工具或库 最好是基于Py
  • 如何让你的精灵在pygame中跳跃

    目前我已经制作了一个平台游戏 可以左右移动我的角色 他从地上开始 关于如何让他跳的任何想法 因为我不明白 目前 如果我按住向上键 我的玩家精灵将连续向上移动 或者如果我按下它 我的玩家精灵将向上移动并保持向上 我想找个办法远离他 让我重新跌
  • 使 matplotlib 图形默认看起来像 R?

    Is there a way to make matplotlib behave identically to R or almost like R in terms of plotting defaults For example R t
  • 如何动态创建 Luigi 任务

    我正在为 Luigi Tasks 构建一个包装器 但遇到了一个障碍Register http luigi readthedocs io en stable modules luigi task register html Register该

随机推荐