训练时如何获得层权重?

2024-03-18

我有一个模型,我想获取特定层的权重矩阵,以便在定义自定义损失函数时使用它。

有没有办法获得模型内部特定层的权重?

附:我目前正在使用 TensorFlow 2 和 keras 功能 API。我测试过如何获取 Keras 中图层的权重? https://stackoverflow.com/questions/43715047/how-do-i-get-the-weights-of-a-layer-in-keras方法,但没有成功。

附言通过使用上述方法,我收到以下错误:

AttributeError                            Traceback (most recent call last)
<ipython-input-26-e0bd481102a7> in <module>
      1 A_DENSE = Dense(1, use_bias = True, name = "A_DENSE")(INPUT)
----> 2 A_DENSE.get_weights()

AttributeError: 'Tensor' object has no attribute 'get_weights'

PPPS如下回答,结合自定义回调和 get_weights 可以解决问题。祝那些和我有类似情况的人好运。


你可以写一个自定义的Callback并在每次纪元结束时使用它。我展示它是为了打印权重,但您可以将它用作自定义损失的一部分。

class CustomCallback(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        rand_int = tf.random.uniform((), 0, 2, dtype=tf.int32)
        print(rand_int)
        
model.fit(X, y epochs = 10, batch_size = 20, validation_split=0.1, callbacks=[CustomCallback()])

更多细节here https://www.tensorflow.org/guide/keras/custom_callback.


例如,这是一个用于打印的虚拟代码weights and biases of layer[1]每个纪元之后。您可以按照您喜欢的方式设置该功能。

from tensorflow.keras import layers, Model, callbacks

class CustomCallback(callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        print(' ')
        print(' ')
        print(model.layers[1].get_weights())
        

X, y = np.random.random((10,5)), np.random.random((10,))

inp = layers.Input((5,))
x = layers.Dense(3)(inp)
out = layers.Dense(1)(x)

model = Model(inp, out)

model.compile(loss='MAE',metrics=['accuracy'])
model.fit(X,y,callbacks=[CustomCallback()], epochs=3)
Epoch 1/3
1/1 [==============================] - ETA: 0s - loss: 0.2346 - accuracy: 0.0000e+00 
 
[array([[ 0.16518219, -0.44628695, -0.07702655],
       [-0.1993848 ,  0.03855793, -0.62964785],
       [ 0.5592851 , -0.28281152, -0.23358124],
       [ 0.05242977,  0.4023881 , -0.19522922],
       [ 0.07936202, -0.40436065,  0.10003945]], dtype=float32), array([ 0.01530731, -0.01565045, -0.01581042], dtype=float32)]
1/1 [==============================] - 0s 2ms/step - loss: 0.2346 - accuracy: 0.0000e+00
Epoch 2/3
1/1 [==============================] - ETA: 0s - loss: 0.2337 - accuracy: 0.0000e+00 
 
[array([[ 0.16814367, -0.4492649 , -0.08000461],
       [-0.19710523,  0.03622784, -0.6319782 ],
       [ 0.55797213, -0.28144714, -0.23221655],
       [ 0.05509637,  0.3996864 , -0.19793113],
       [ 0.07731982, -0.40226308,  0.10213734]], dtype=float32), array([ 0.01846951, -0.01881272, -0.01897269], dtype=float32)]
1/1 [==============================] - 0s 7ms/step - loss: 0.2337 - accuracy: 0.0000e+00
Epoch 3/3
1/1 [==============================] - ETA: 0s - loss: 0.2322 - accuracy: 0.0000e+00 
 
[array([[ 0.16706704, -0.448164  , -0.07889817],
       [-0.19894598,  0.0381193 , -0.63007975],
       [ 0.5558067 , -0.27921563, -0.22997847],
       [ 0.05663134,  0.3981127 , -0.19951159],
       [ 0.07536169, -0.400249  ,  0.10415838]], dtype=float32), array([ 0.01846951, -0.01881272, -0.01897269], dtype=float32)]
1/1 [==============================] - 0s 2ms/step - loss: 0.2322 - accuracy: 0.0000e+00
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

训练时如何获得层权重? 的相关文章

随机推荐

  • 修改ggplot2 Y轴以使用整数而不强制执行上限[重复]

    这个问题在这里已经有答案了 我正在尝试修改 ggplot2 中的轴 以便它是一个小数点 并且每个整数都有一个标签 但是 我想这样做没有上限 以便它会自动调整不同计数的数据 我的问题和问题之间的区别这里提出的问题 https stackove
  • 在 Linux 中的类路径上使用 javac 和多个特定的 jar(波浪号在冒号后不扩展)

    我正在尝试通过类似于以下的命令编译一个使用两个 jar 文件 trove 和 apache commons 集合 的 java 源文件 javac cp m2 repository gnu trove trove 3 0 0 trove 3
  • 如何获取ASCII后面的二进制代码(C#)

    我试图找出如何将控制台的输入转换为二进制 如何在 C 中进行这样的转换 先感谢您 string s Console ReadLine byte bytes Encoding ASCII GetBytes s 请注意 控制台使用的编码实际上不
  • pgsql 返回表错误:列引用不明确

    我不断收到此错误 列引用 人 不明确 我需要返回一个表 个人整数 当我使用 SETOF 整数时它工作正常 但在这种情况下它不起作用 我的另一个函数 recurse 完美地返回一组整数 CREATE OR REPLACE FUNCTION t
  • 来自基类的用户定义转换运算符

    介绍 我知道 不允许用户定义的与基类之间的转换 MSDN 给出了对此规则的解释 你不需要这个运算符 我确实了解用户定义的转换to不需要基类 因为这显然是隐式完成的 但是 我确实需要转换from一个基类 在我当前的设计 非托管代码的包装器 中
  • Matplotlib 轴位置和颜色条对齐

    我正在尝试将多个颜色条与使用其中之一生成的子图对齐gridspec or fig add subplots 我想添加颜色条fig add axes在 matplotlib v2 02 中 因为它允许详细的对齐控制 但是 我需要获取图形位置才
  • 熊猫从日期时间列中获取第二个最小值[重复]

    这个问题在这里已经有答案了 我有一个带有日期时间列的数据框 我可以通过使用获得最小值 df Date min 我怎样才能得到第二个 第三个 最小值 Use nlargest or nsmallest 对于第二大的 series nlarge
  • Symfony2 和控制器中的 DRY 方法

    我正在使用 Symfony2 为我的公司开发一个小型 CMS 我真的很喜欢这个框架 我喜欢表单类并重用它们 毕竟这都是关于表单的 但是 是的 有一个 但是 我感觉我在做同样的事情 复制并粘贴到所有控制器中 我们讨厌的代码重复 随着所有业务逻
  • MbUnit:比较双打的最优雅的方式?

    The code Assert AreEqual 9 97320998018748d observerPosition CenterLongitude produces Expected Value Actual Value 9 97320
  • 安全地向 RESTFUL API 提供凭据

    我创建了一个 RESTful 服务器应用程序 它可以在有用的 URL 例如 www site com get someinfo 上处理请求并提供服务 它是在春天建造的 但是 这些访问受密码保护 我现在正在构建一个客户端应用程序 它将连接到这
  • Angular 6 生产版本“无法绑定到‘disabled’,因为它不是‘div’的已知属性”

    我的应用程序在使用 JIT 编译器时似乎可以工作 但是当我尝试使用 AOT 编译器时ng build prod然后它抛出一个错误 ERROR in Can t bind to disabled since it isn t a known
  • 很难理解express.js中的“next/next()”

    这是一个例子 Configuration app configure function app set views dirname views app set view engine jade app use express bodyPar
  • 在asp.net mvc中通过slug进行路由

    我有一个控制器操作 如下所示 public ActionResult Content string slug var content contentRepository GetBySlug slug return View content
  • navigator.webkitPersistentStorage.requestQuota 是否适用于 IndexedDB?

    使用今天最新版本的 Android Chrome 我可以使用以下命令请求持久性 IndexedDB 存储吗 navigator webkitPersistentStorage requestQuota var requestedBytes
  • 使用 Liquid 按字母顺序对帖子进行排序

    有没有办法使用 Jekyll 按字母顺序对多个帖子进行排序 我现在有这样的事情 for post in site categories threat li a href post title a li endfor 它有效 但帖子很混乱 我
  • 为什么在这种情况下重写不改变表达式的类型?

    在 1 中可以阅读下一篇 rewrite prf in expr 如果我们有prf x y 并且 expr 所需的类型是以下属性x the rewrite in语法将搜索x在所需的类型中expr并将其替换为y 现在 我有下一段代码 您可以将
  • Microsoft Botframework:与 Bot 通道直接对话

    我一直在努力从 C 控制台应用程序向托管在 Azure 中的 Skype 机器人发送直接消息 但我不断收到错误 操作返回无效的状态代码 未经授权 但我提供了以下凭据 Web 配置文件
  • 输入密码后启动 Shiny 应用程序(使用 Shinydashboard)

    In this topic https stackoverflow com questions 28987622 starting shiny app after password input rq 1很好地解释了如何在输入密码后启动shi
  • 为什么使用LabVIEW? [关闭]

    很难说出这里问的是什么 这个问题是含糊的 模糊的 不完整的 过于宽泛的或修辞性的 无法以目前的形式得到合理的回答 如需帮助澄清此问题以便重新打开 访问帮助中心 help reopen questions 我正在学习使用 LabVIEW 作为
  • 训练时如何获得层权重?

    我有一个模型 我想获取特定层的权重矩阵 以便在定义自定义损失函数时使用它 有没有办法获得模型内部特定层的权重 附 我目前正在使用 TensorFlow 2 和 keras 功能 API 我测试过如何获取 Keras 中图层的权重 https