在 TensorFlow 张量上调用 Keras 模型但保留权重

2023-11-22

In Keras 作为 TensorFlow 的简化接口:教程他们描述了如何在 TensorFlow 张量上调用 Keras 模型。

from keras.models import Sequential

model = Sequential()
model.add(Dense(32, activation='relu', input_dim=784))
model.add(Dense(10, activation='softmax'))

# this works! 
x = tf.placeholder(tf.float32, shape=(None, 784))
y = model(x)

他们还说:

注意:通过调用 Keras 模型,您可以重用其架构和权重。当您在张量上调用模型时,您将在输入张量之上创建新的 TF 操作,并且这些操作将重用模型中已存在的 TF 变量实例。

我将其解释为模型的权重在y如模型中所示。然而,对我来说,生成的 Tensorflow 节点中的权重似乎已重新初始化。下面是一个最小的例子:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense
# Create model with weight initialized to 1
model = Sequential()
model.add(Dense(1, input_dim=1, kernel_initializer='ones',
                bias_initializer='zeros'))
model.compile(loss='binary_crossentropy', optimizer='adam',
              metrics=['accuracy'])

# Save the weights 
model.save_weights('file')

# Create another identical model except with weight initialized to 0
model2 = Sequential()
model2.add(Dense(1, input_dim=1, kernel_initializer='zeros',
                 bias_initializer='zeros'))
model2.compile(loss='binary_crossentropy', optimizer='adam',
               metrics=['accuracy'])
# Load the weight from the first model
model2.load_weights('file')
# Call model with Tensorflow tensor
v = tf.Variable([[1, ], ], dtype=tf.float32)
node = model2(v)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
print(sess.run(node), model2.predict(np.array([[1, ], ])))
# Prints (array([[ 0.]], dtype=float32), array([[ 1.]], dtype=float32))

为什么我要这样做:

我想在另一个最小化方案中使用经过训练的网络,网络“惩罚”搜索空间中不允许的位置。因此,如果您有不涉及这种特定方法的想法,我们也非常感激。


终于找到了答案。问题中的示例有两个问题。

1:

第一个也是最明显的是我称之为tf.global_variables_intializer()函数将重新初始化会话中的所有变量。相反,我应该打电话给tf.variables_initializer(var_list) where var_list是要初始化的变量列表。

2:

第二个问题是 Keras 没有使用与本机 Tensorflow 对象相同的会话。这意味着能够运行张量流对象model2(v)与我的会议sess它需要重新初始化。再次Keras 作为张量流的简化接口:教程能够提供帮助

我们应该首先创建一个 TensorFlow 会话并将其注册到 Keras。这意味着 Keras 将使用我们注册的会话来初始化它内部创建的所有变量。

import tensorflow as tf
sess = tf.Session()

from keras import backend as K
K.set_session(sess)

如果我们将这些更改应用于我的问题中提供的示例,我们将得到以下代码,该代码完全符合预期。

from keras import backend as K
from keras.models import Sequential
from keras.layers import Dense
sess = tf.Session()
# Register session with Keras
K.set_session(sess)
model = Sequential()
model.add(Dense(1, input_dim=1, kernel_initializer='ones',
                bias_initializer='zeros'))
model.compile(loss='binary_crossentropy', optimizer='adam',
              metrics=['accuracy'])
model.save_weights('test')

model2 = Sequential()
model2.add(Dense(1, input_dim=1, kernel_initializer='zeros',
                 bias_initializer='zeros'))
model2.compile(loss='binary_crossentropy', optimizer='adam',
               metrics=['accuracy'])
model2.load_weights('test')
v = tf.Variable([[1, ], ], dtype=tf.float32)
node = model2(v)
init = tf.variables_initializer([v, ])
sess.run(init)
print(sess.run(node), model2.predict(np.array([[1, ], ])))
# prints: (array([[ 1.]], dtype=float32), array([[ 1.]], dtype=float32))

结论:

教训是,在混合 Tensorflow 和 Keras 时,请确保所有内容都使用相同的会话。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在 TensorFlow 张量上调用 Keras 模型但保留权重 的相关文章

随机推荐

  • Apache Spark 处理倾斜数据

    我有两张桌子想连接在一起 其中之一的数据偏差非常严重 这导致我的 Spark 作业无法并行运行 因为大部分工作都是在一个分区上完成的 我听过 读过并尝试对我的密钥进行加盐以增加分发 https www youtube com watch v
  • “使用警告”与“#!/usr/bin/perl -w”有区别吗?

    我读到最好use warnings 而不是放置一个 w在shebang的最后 两者有什么区别 警告编译指示是命令行标志 w 的替代品 但编译指示仅限于封闭块 而标志是全局的 看佩勒克斯警告了解更多信息和内置警告类别列表 warnings文档
  • 为什么“find . -name *.txt | xargs du -hc”给出多个总计?

    我有一大堆目录 我正在尝试计算其中数百个 txt 文件的总大小 我尝试过这个 大部分有效 find name txt xargs du hc 但最后我没有给我一个总数 而是得到了几个 我的猜测是 管道一次只会传递这么多行 find 的输出
  • Cypress:存根打开窗口

    在我的应用程序中有一个推荐列表 单击该列表会打开一个带有动态地址的新窗口 window open shopURL blank 现在我正在尝试存根 windows open 事件 如下所示https github com cypress io
  • 如何查找AWS S3存储桶中的重复文件?

    有没有办法在 Amazon S3 存储桶中递归查找重复文件 在普通文件系统中 我只需使用 fdupes r my directory Amazon S3 中没有 查找重复项 命令 但是 您确实执行以下操作 检索一个对象列表在桶里 寻找具有以
  • 如何返回至少 4D 的数组:模拟 numpy.atleast_4d 的高效方法

    numpy 提供了三个方便的例程来将数组转换为至少 1D 2D 或 3D 数组 例如通过numpy atleast 3d 我需要多一维的等价物 atleast 4d 我可以想到使用嵌套 if 语句的各种方法 但我想知道是否有更有效和更快的方
  • 我可以从 PowerShell 访问我的自定义 .NET 类吗?

    我对 PowerShell 和 NET 类有一些问题和疑问 我正在尝试编写一个 foo 类 它将调用 Rest Web 服务并执行一些任务 如果我在 GAC 中部署该类 那么我可以从 PowerShell 调用它吗 Try ADD TYPE
  • 此版本的 Realm 不支持在 Realm Studio 中打开格式版本 11 的 Realm 文件

    我正在使用 React Native 在此我指的是this在 React Native 中使用领域数据库的文档 我可以创建react native数据库 但无法在Realm Studio V3 11 0中打开它 当我在工作室中打开保存的 R
  • 使用 JavaScript 创建 HTML 文件

    我正在寻找一种使用 JavaScript 函数在本地目录中创建新 HTML 文件的方法 这可能吗 Thanks 客户端 是的 但您可能需要创建一个新的 ActiveX 对象 因此浏览器只能是 IE 服务器端 只需使用任何服务器端脚本语言 J
  • 确定枚举值是否在列表中 (C#)

    我正在构建一个有趣的小应用程序来确定我是否应该骑自行车上班 我想测试一下是下雨还是雷雨 public enum WeatherType byte Sunny 0 Cloudy 1 Thunderstorm 2 Raining 4 Snowi
  • 将 csv 数据加载到 Hive 表时出错

    我在 hadoop 中有一个 csv 文件 并且有一个 Hive 表 现在我想将该 csv 文件加载到此 Hive 表中 我已使用 load LOAD DATA local path to csv file 覆盖 INTO TABLE 表名
  • React Native,TouchableOpacity 包裹浮动按钮什么也得不到

    我正在创建一个简单的操作按钮 浮动按钮 这是工作
  • 如何保护dll?

    如何保护我的项目的dll 使其不被其他人引用和使用 Thanks 简而言之 除了显而易见的事情之外 您无能为力 您可能需要考虑的明显事情 大致按照难度增加和合理性降低的顺序 包括 静态链接 因此没有 DLL 可供攻击 删除所有符号 使用 D
  • javascript:带有 html 标签的 focusOffset

    我有一个 contenteditable div 如下 光标位置 div lorem ipsum div
  • Python 列表理解,具有独特的项目

    有没有办法在 Python 中创建仅包含唯一项的列表理解 我最初的想法是使用这样的东西 new items unicode item for item in items 然而 我后来意识到我需要省略重复的项目 所以我最终得到了这个丑陋的怪物
  • android - 如何使用 achartengine 更改图表的背景颜色

    我使用 achartengine 实现了折线图 但我想改变折线图的背景颜色 有人建议使用以下代码来更改背景颜色 mRenderer setApplyBackgroundColor true mRenderer setBackgroundCo
  • 用于创建尚不存在的内容的函数名称

    我有时会编写一个函数 如果尚不存在 则只创建一些东西 否则不执行任何操作 名字像CreateFooIfNecessary or EnsureThereIsAFoo 做工作 但他们感觉有点笨拙 也可以说GetFoo 但这个名字并不意味着foo
  • 捕获另一个表单抛出的异常

    我正在尝试这样做 我正在创建另一个表单 它的 FormClosed 方法会抛出一个异常 该异常应该由主表单捕获 主要形式 try frmOptions frm new frmOptions frm ShowDialog catch Exce
  • 循环内的 JavaScript 闭包 – 简单的实际示例

    var funcs let s create 3 functions for var i 0 i lt 3 i and store them in funcs funcs i function each should log its val
  • 在 TensorFlow 张量上调用 Keras 模型但保留权重

    In Keras 作为 TensorFlow 的简化接口 教程他们描述了如何在 TensorFlow 张量上调用 Keras 模型 from keras models import Sequential model Sequential m