使用 Keras 获取模型输出相对权重的梯度

2023-11-26

我对利用 Keras API 的简单性构建强化学习模型感兴趣。不幸的是,我无法提取输出相对于权重的梯度(不是误差)。我发现以下代码执行类似的功能(神经网络的显着图(使用 Keras))

get_output = theano.function([model.layers[0].input],model.layers[-1].output,allow_input_downcast=True)
fx = theano.function([model.layers[0].input] ,T.jacobian(model.layers[-1].output.flatten(),model.layers[0].input), allow_input_downcast=True)
grad = fx([trainingData])

关于如何计算模型输出相对于每层权重的梯度的任何想法将不胜感激。


要使用 Keras 获取模型输出相对于权重的梯度,您必须使用 Keras 后端模块。我创建了这个简单的示例来准确说明要做什么:

from keras.models import Sequential
from keras.layers import Dense, Activation
from keras import backend as k


model = Sequential()
model.add(Dense(12, input_dim=8, init='uniform', activation='relu'))
model.add(Dense(8, init='uniform', activation='relu'))
model.add(Dense(1, init='uniform', activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

为了计算梯度,我们首先需要找到输出张量。对于模型的输出(我最初的问题所问的),我们简单地称为 model.output。我们还可以通过调用找到其他层的输出梯度model.layers[index].output

outputTensor = model.output #Or model.layers[index].output

然后我们需要选择与梯度相关的变量。

  listOfVariableTensors = model.trainable_weights
  #or variableTensors = model.trainable_weights[0]

我们现在可以计算梯度。它很简单,如下所示:

gradients = k.gradients(outputTensor, listOfVariableTensors)

要实际运行给定输入的梯度,我们需要使用一些 Tensorflow。

trainingExample = np.random.random((1,8))
sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())
evaluated_gradients = sess.run(gradients,feed_dict={model.input:trainingExample})

就是这样!

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用 Keras 获取模型输出相对权重的梯度 的相关文章

随机推荐

  • 是否存在 SHA1(x) 等于 x 的 x?

    有没有一个x where SHA1 x x 我正在寻找证据或强有力的论据来反对它 与问题相同的论点适用于此有MD5定点吗 IE 对于随机选择的函数 该值约为 63
  • Lua表的一个有趣现象

    我是Lua新手 这几天正在学习table的用法 从教程中我知道Lua对待数字索引项和非数字索引项的方式不同 所以我自己做了一些测试 今天我发现一个有趣的现象 我无法解释它 The code t 1 2 3 a a b b print t g
  • android webview youtube 嵌入视频自动播放不起作用

    我无法自动播放我的视频 请帮忙 我的sdk版本 android minSdkVersion 14 android targetSdkVersion 19 gt 我尝试按照代码中指定的方式放置 JavaScript public void o
  • for循环中分号放错位置[重复]

    这个问题在这里已经有答案了 当我做作业时 我犯了一个小错误 在 for 循环中像下面的代码一样 for i 0 i
  • 如何从表单外部捕获表单的某些事件?

    我正在做一些需要监控多种表格的事情 从表单外部 并且不将任何代码放入表单内 我需要以某种方式从这些表单捕获事件 很可能以 Windows 消息的形式 但是 如何从与其相关的类外部捕获 Windows 消息呢 我的项目有一个对象 它包装了它正
  • 如何设置 C++ 函数以便 p/invoke 使用它?

    希望这是一个无脑简单的问题 但这表明我缺乏 C 专业知识 我是一名 C 程序员 过去我使用 P Invoke 和其他人的 C C dll 进行了大量工作 然而 这次我决定自己编写一个包装器 C dll 非托管 然后从 C 调用我的包装器 d
  • 如何使用 Meteor.js 对 Dropbox API 进行 CURL 调用

    我是 Meteor js 新手 希望让我的 Web 应用程序能够与 Dropbox Core API 配合使用 我无法全神贯注于使用 Meteor js 中的 HTTP 包进行 API 调用 如何在 Meteor 中进行类似于下面的 Cur
  • 将常数(2 的幂)除以整数的技巧

    NOTE这是一个理论问题 我对实际代码的性能感到满意 我只是好奇是否有替代方案 有没有一种技巧可以将常量值 本身是 2 的整数幂 除以整数变量值 而无需使用实际的除法运算 The fixed value of the numerator d
  • Ngit 与私钥文件建立连接

    我正在尝试使用 NGit 连接到 Github 即使用私钥和密码 有人可以引导我完成它吗 我的正常获取是 var git Git CloneRepository SetDirectory properties OutputPath SetU
  • 如何从 Vue Composition API / Vue 3.0 + TypeScript 中的组合函数访问根上下文?

    我想创建可重用的包装函数写在打字稿用于通过使用触发 toast 通知复合函数 如 Vue 3 0 的当前规范中所定义 组合 API RFC 此示例使用 BootstrapVue v2 0 toast 组件 对于 Vue 2 它将通过以下方式
  • 将设置保留在数据库中

    在可重用的应用程序中 我不想更改任何代码 我想更改应用程序使用的设置变量 以其形式和其他部分 为动态的 从数据库表更新其内容 最好的方法是什么 也许是中间件 看看Django 数据库设置项目
  • 如何从 numpy 数组生成音频?

    我想从 numpy 中的 2D 数组创建 心率监视器 效果 并希望音调反映数组中的值 您可以使用write功能 from scipy io wavfile创建一个 wav 文件 然后您可以随意播放该文件 请注意 数组必须是整数 因此如果有浮
  • RuntimeError:Matplotlib 动画中没有可用的 MovieWriters

    我遇到的问题是类似于此示例的代码 https matplotlib org examples animation basic example writer html 错误 运行时错误 没有可用的 MovieWriters发生在Writer
  • 赋予 PHP include() 文件父变量作用域

    无论如何 是否可以在调用它的父范围中使用包含的文件 以下示例经过简化 但完成相同的工作 本质上 文件将被函数包含 但希望包含的文件的范围是调用包含该文件的函数的范围 main php
  • JsonEditor 与 Django Admin 集成

    我正在努力整合JSON编辑器进入 Django 管理员 我的模型中有一个字段使用 Postgres JSON 并且该库中的树编辑器非常完美 模型 py class Executable models Model Simplified mod
  • 如何以编程方式创建 jms Topic 和 TopicConnectionFactory?

    有人知道是否可以以编程方式创建主题及其连接工厂吗 目前 我使用 glassfish 管理实用程序来创建我的主题及其连接工厂 如果我无法在代码中创建它 glassfish openmq 是否有我可以使用的默认主题和 conn 工厂 如果您只想
  • Chrome 64 未捕获 DOMException:无法在“CSSStyleSheet”上执行“insertRule”:无法访问 StyleSheet 来 insertRule

    啊 我的网站在 Chrome 中损坏了 在控制台中获取此消息 Uncaught DOMException Failed to execute insertRule on CSSStyleSheet Cannot access StyleSh
  • jQuery ajax 中有进度更新事件吗?

    我有一个长时间运行的任务 使用 jquery ajax 调用 我正在使用阻止 ui 插件显示 加载中 无论如何 我可以将进度消息发送回客户端以显示进度 并在块 ui 插件消息上更新该进度 所以它会显示这一点 当服务器完成其工作时 正在加载第
  • AutoMapper.dll 中发生“AutoMapper.AutoMapperMappingException”类型的异常,但未在用户代码中处理

    不知何故 我的代码不再工作 它之前使用完全相同的代码确实可以工作 这就是问题 The code 我正在尝试使用以下代码将一些对象映射到 ViewModel 配置 Mapper CreateMap
  • 使用 Keras 获取模型输出相对权重的梯度

    我对利用 Keras API 的简单性构建强化学习模型感兴趣 不幸的是 我无法提取输出相对于权重的梯度 不是误差 我发现以下代码执行类似的功能 神经网络的显着图 使用 Keras get output theano function mod