构建 keras 模型

2024-02-10

我不明白这段代码中发生了什么:

def construct_model(use_imagenet=True):
    # line 1: how do we keep all layers of this model ?
    model = keras.applications.InceptionV3(include_top=False, input_shape=(IMG_SIZE, IMG_SIZE, 3),
                                          weights='imagenet' if use_imagenet else None) # line 1: how do we keep all layers of this model ?

    new_output = keras.layers.GlobalAveragePooling2D()(model.output)

    new_output = keras.layers.Dense(N_CLASSES, activation='softmax')(new_output)
    model = keras.engine.training.Model(model.inputs, new_output)
    return model

具体来说,我的困惑是,当我们调用最后一个构造函数时

model = keras.engine.training.Model(model.inputs, new_output)

我们指定输入层和输出层,但它如何知道我们希望所有其他层保留?

换句话说,我们将 new_output 层附加到我们在第 1 行加载的预训练模型,即 new_output 层,然后在最终的构造函数(最后一行)中,我们只需创建并返回一个具有指定输入和的模型输出层,但它如何知道我们想要在中间放置哪些其他层?

附带问题1):keras.engine.training.Model 和 keras.models.Model 有什么区别?

附带问题 2):当我们执行 new_layer = keras.layers.Dense(...)(prev_layer) 时到底会发生什么? () 操作是否返回新层,它到底做了什么?


该模型是使用功能性API模型 https://keras.io/getting-started/functional-api-guide/

基本上它的工作原理是这样的(也许如果你在阅读本文之前转到下面的“附带问题2”,它可能会变得更清楚):

  • 你有一个输入张量(您也可以将其视为“输入数据”)
  • 您创建(或重用)图层
  • 您将输入张量传递给一个层(您用输入“调用”一个层)
  • 你得到一个输出张量

您继续使用这些张量,直到创建了整个张量graph.

但这还没有创建一个“模型”。 (你可以训练和使用其他东西)。
你所拥有的只是一张图表,告诉你哪些张量去哪里。

要创建模型,您需要定义其起点和终点。


在例子中。

  • 他们采用现有模型:model = keras.applications.InceptionV3(...)
  • 他们想要扩展这个模型,所以他们得到了它输出张量: model.output
  • 他们将此张量作为输入GlobalAveragePooling2D
  • 他们得到该层的输出张量为new_output
  • 他们将其作为输入传递给另一层:Dense(N_CLASSES, ....)
  • 并得到它的输出new_output(这个变量被替换,因为他们对保留其旧值不感兴趣......)

但是,由于它与函数式 API 一起使用,我们还没有模型,只有图表。为了创建模型,我们使用Model定义输入张量和输出张量:

new_model = Model(old_model.inputs, new_output)    

现在你有了你的模型。
如果你像我一样在另一个变量中使用它(new_model),旧模型仍然存在model。这些模型共享相同的层,每当你训练其中一个模型时,另一个模型也会更新。


问题:它如何知道我们想要在中间添加哪些其他层?

当你这样做时:

outputTensor = SomeLayer(...)(inputTensor)    

输入和输出之间有连接。 (Keras 将使用内部张量流机制并将这些张量和节点添加到图中)。如果没有输入,输出张量就不可能存在。整个InceptionV3模型从头到尾都是连接的。它的输入张量经过所有层以产生输出张量。数据遵循的方式只有一种可能,而图表就是方式。

当您获得该模型的输出并使用它来获得进一步的输出时,所有新输出都将连接到此模型,从而连接到模型的第一个输入。

大概是属性_keras_history添加到张量中的值与其跟踪图的方式密切相关。

所以,做Model(old_model.inputs, new_output)自然会遵循唯一可能的方式:图表。

如果您尝试使用未连接的张量执行此操作,您将收到错误。


附带问题1

更喜欢从“keras.models”导入。基本上,该模块将从其他模块导入:

  • https://github.com/keras-team/keras/blob/master/keras/models.py https://github.com/keras-team/keras/blob/master/keras/models.py

请注意该文件keras/models.py进口Model from keras.engine.training。所以,这是同样的事情。

附带问题2

它不是new_layer = keras.layers.Dense(...)(prev_layer).

It is output_tensor = keras.layers.Dense(...)(input_tensor).

你在同一行做两件事:

  • 创建一个图层 - 使用keras.layers.Dense(...)
  • 使用输入张量调用层以获得输出张量

如果您想使用具有不同输入的同一层:

denseLayer = keras.layers.Dense(...) #creating a layer

output1 = denseLayer(input1)  #calling a layer with an input and getting an output
output2 = denseLayer(input2)  #calling the same layer on another input
output3 = denseLayer(input3)  #again   

奖励 - 创建一个与顺序模型相同的功能模型

如果您创建此顺序模型:

model = Sequential()
model.add(Layer1(...., input_shape=some_shape))   
model.add(Layer2(...))
model.add(Layer3(...))

你所做的与以下完全相同:

inputTensor = Input(some_shape)
outputTensor = Layer1(...)(inputTensor)
outputTensor = Layer2(...)(outputTensor)    
outputTensor = Layer3(...)(outputTensor)

model = Model(inputTensor,outputTensor)

有什么不同?

嗯,函数式 API 模型是完全免费的,可以按照您想要的方式构建。您可以创建分支:

out1 = Layer1(..)(inputTensor)    
out2 = Layer2(..)(inputTensor)

您可以加入张量:

joinedOut = Concatenate()([out1,out2])   

有了这个,您可以创建anything你想要各种奇特的东西,分支,门,串联,添加等等,这是顺序模型无法做到的。

事实上,一个Sequential模型也是一个Model,但创建它是为了在没有分支的模型中快速使用。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

构建 keras 模型 的相关文章

随机推荐

  • 如何动态添加和扩展私有数据集合?

    设想 I have 3个组织 O1 O3 O1 是申请人的组织 O2 O3 管理与他们共享的公共和私人数据 O1 O3 彼此共享私有数据 O1 O2 共享私有数据 网络正在运行 集合已经分发 一切正常 当我现在想要添加更多组织 以千计 O4
  • 为什么从另一个文件导入类会调用 __init__ 函数?

    该项目的结构是 project 主 py 会话 py 蜘蛛 py session py中有一个类 import requests class Session def init self self session requests Sessi
  • 如何通过反射找出方法的可见性?

    Context 我正在尝试学习 练习 TDD 并决定我需要创建一个不可变的类 为了测试 不变性不变量 你能这么说吗 我想我只需通过反射调用类中的所有公共方法 然后检查类之后是否没有更改 这样我以后就不太可能不小心破坏这个不变量了 这本身可能
  • 为什么Python中的元组可以使用reversed但没有__reversed__?

    在讨论中这个答案 https stackoverflow com questions 9449674 how to implement a persistent python list 9449852 9449852我们意识到元组没有 re
  • 更正应用程序的类路径,使其包含类 Log4J2LoggingSystem 和 PropertiesUtil 的兼容版本

    我正在将一个项目从 Spring Boot 2 6 1 迁移到 Spring Boot 3 0 2 但我遇到了 log4j 依赖项版本的问题 我已经修改了所有给我带来问题的依赖项 但我仍然无法解决问题 错误如下 Java HotSpot T
  • Flowplayer 播放一切

    我有一个flowplayer我正在使用它 下面有几张图片 当您点击这些图片时dialog是用这些图片的放大版本创建的 问题是flowplayer永远会在最上面dialog 我尝试过设置z index of the dialog高和flowp
  • 如何在 SwiftUI 中处理拖动到停靠栏图标上的操作?

    我已经设置了一个 SwiftUI 应用程序 它似乎接受拖放到停靠图标上的图像 但我无法弄清楚在应用程序代码中处理拖放图像的位置 如何处理将图像 或任何特定文件 拖放到 SwiftUI 应用程序的停靠图标上 背景 对于使用 NSApplica
  • 将枚举数据绑定到 WPF + MVVM 中的组合框

    我读了这个非常相关的问题在这里 https stackoverflow com questions 58743 databinding an enum property to a combobox in wpf 由于答案中的链接 这非常有帮
  • Golang:将文件附加到现有的 tar 存档中

    如何将文件附加到 Go 中现有的 tar 存档中 我没有看到任何明显的东西docs http golang org pkg archive tar 关于如何去做 我有一个已经创建的 tar 文件 我想在它关闭后向其中添加更多内容 EDIT
  • 为什么我不必在第二个 TableViewController 中释放 ManagedObjectContext

    我有两个显示 CoreData 对象的表视图控制器 一种是详细视图 带句子 一种是概述 带故事 选择一个故事 gt 查看句子 看来我过度释放了管理对象上下文 我最初在 dealloc 的两个 TableViewController 中发布了
  • 优化Python代码

    关于优化此 python 代码的任何提示寻找下一个回文 输入号码可以为1000000位 添加评论 usr bin python def inc lst lng this function first extract the left hal
  • 修复 Swift 3 中的警告“C-style for Statement is deprecated”

    我有更新Xcode到 7 3 现在我对用于创建随机字符串的函数发出警告 我尝试过改变for声明与for i in 0 lt len 然而 警告变成了错误 我怎样才能删除警告 static func randomStringWithLengt
  • Swift stdlib 工具错误

    我在使用 Xcode 8 1 和 Swift 3 编译时遇到此错误 Swift stdlib 工具错误 编译日志的末尾如下所示 Users Library Developer Xcode DerivedData Build Products
  • 让用户将记录器注入 Nodejs 模块的最佳实践

    我为 nodejs 编写了这个模块 可用于通过 sockjs 从任何地方向客户端分派事件 现在我想包括一些可配置的日志记录机制 目前 我将 winston 添加为依赖项 要求它作为每个类中的记录器并使用 logger error logge
  • 如何使用 MATLAB 和 JDBC 加速表检索?

    我正在使用 MATLAB 调用的 JDBC 访问 PostGreSQL 8 4 数据库 我感兴趣的表基本上由不同数据类型的各个列组成 他们是通过时间戳来选择的 由于我想检索大量数据 因此我正在寻找一种使请求比现在更快的方法 我现在正在做的事
  • 如何在 XAML 中使用 C# 中定义的画笔资源

    到目前为止我有这个
  • 新的 Conda 环境以及适用于 Jupyter Notebook 的最新 Python 版本

    由于 Python 版本变化很少 我总是忘记如何使用最新的 Python for Jupyter Notebook 创建新的 Conda 环境 所以我想下次将其列出来 从 StackOverflow 来看 有一些答案不再有效 下面是我在 S
  • 从 Apache Cordova 开始

    我刚刚下载了 Apache Cordova 似乎有特定于平台的版本 在将其移植到另一个平台之前 我是否必须为特定平台编写代码 是否可以创建一个多平台项目 我是否正确理解了我应该开始工作的方式 Apache Cordova 主页也是这么说的
  • 网络应用程序的照片存储[重复]

    这个问题在这里已经有答案了 可能的重复 用户镜像 数据库与文件系统存储 https stackoverflow com questions 585224 user images database vs filesystem storage
  • 构建 keras 模型

    我不明白这段代码中发生了什么 def construct model use imagenet True line 1 how do we keep all layers of this model model keras applicat