获取中间层(Functional API)的输出并在SubClassed API中使用

2024-02-12

In the 喀拉斯文档 https://keras.io/getting_started/faq/,它说如果我们想选择中间层模型的输出(顺序和功能),我们需要做的如下:

model = ...  # create the original model

layer_name = 'my_layer'
intermediate_layer_model = keras.Model(inputs=model.input,
                                       outputs=model.get_layer(layer_name).output)
intermediate_output = intermediate_layer_model(data)

所以,这里我们有两个模型,intermediate_layer_model是其父模型的子模型。而且他们也是独立的。同样,如果我们得到中间层的输出特征图父模型(或基础模型)的,以及做一些操作有了它并从这个操作中得到一些输出特征图,然后我们也可以估算这个输出特征图返回到父模型。


input = tf.keras.Input(shape=(size,size,3))
model = tf.keras.applications.DenseNet121(input_tensor = input)

layer_name = "conv1_block1" # for example 
output_feat_maps = SomeOperationLayer()(model.get_layer(layer_name).output)  

# assume, they're able to add up
base = Add()([model.output, output_feat_maps])

# bind all 
imputed_model = tf.keras.Model(inputs=[model.input], outputs=base)

这样,我们就有了一个修改后的模型。使用函数式 API 非常容易。一切kerasimagenet 模型(大部分)是用函数式 API 编写的。在模型子类化API中,我们可以使用这些模型。我这里关心的是,如果我们需要这些功能性API模型内部的中间输出特征图怎么办call功能。

class Subclass(tf.keras.Model): 
    def __init__(self, dim):
         super(Subclass, self).__init__()
         self.dim = dim
         self.base = DenseNet121(input_shape=self.dim)

         # building new model with the desired output layer of base model 
         self.mid_layer_model = tf.keras.Model(self.base.inputs, 
                                     self.base.get_layer(layer_name).output)

    def call(self, inputs):
         # forward with base model 
         x = self.base(inputs)

         # forward with mid_layer_model 
         mid_feat = self.mid_layer_model(inputs)

         # do some op with it 
         mid_x = SomeOperationLayer()(mid_feat)
         
         # assume, they're able to add up
         out = tf.keras.layers.add([x, mid_x])

         return out 

问题是,我们在技术上两种型号以联合的方式。但与构建这样的模型不同,这里我们只需要基本模型前向方式的中间输出特征图(来自某些输入)并在其他地方使用它并获得一些输出。像这样

mid_x = SomeOperationLayer()(self.base.get_layer(layer_name).output)

但它给了ValueError: Graph disconnected。因此,目前,我们必须根据我们想要的中间层从基础模型构建一个新模型。在里面init我们定义或创建新的方法self.mid_layer_model模型给出了我们想要的输出特征图,如下所示:mid_feat = self.mid_layer_model(inputs)。接下来,我们采取mid_faet并进行一些操作并获得一些输出,最后将它们添加tf.keras.layers.add([x, mid_x])。因此,通过创建具有所需中间输出的新模型,但同时,我们重复相同的操作两次,即基本模型及其子集模型。也许我遗漏了一些明显的东西,请添加一些东西。是这样吗!或者我们可以采取一些策略。我在论坛问过here https://github.com/tensorflow/tensorflow/issues/47544https://github.com/tensorflow/tensorflow/issues/47544,还没有回复。


Update

这是一个工作示例。假设我们有一个像这样的自定义层

import tensorflow as tf
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.layers import Add
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Flatten

class ConvBlock(tf.keras.layers.Layer):
    def __init__(self, kernel_num=32, kernel_size=(3,3), strides=(1,1), padding='same'):
        super(ConvBlock, self).__init__()
        # conv layer
        self.conv = tf.keras.layers.Conv2D(kernel_num, 
                        kernel_size=kernel_size, 
                        strides=strides, padding=padding)
        # batch norm layer
        self.bn = tf.keras.layers.BatchNormalization()

    def call(self, input_tensor, training=False):
        x = self.conv(input_tensor)
        x = self.bn(x, training=training)
        return tf.nn.relu(x)

我们想要将这一层归咎于 ImageNet 模型并构建一个像这样的模型

input = tf.keras.Input(shape=(32, 32, 3))
base = DenseNet121(weights=None, input_tensor = input)

# get output feature maps of at certain layer, ie. conv2_block1_0_relu
cb = ConvBlock()(base.get_layer("conv2_block1_0_relu").output)
flat = Flatten()(cb)
dense = Dense(1000)(flat)

# adding up
adding = Add()([base.output, dense])
model = tf.keras.Model(inputs=[base.input], outputs=adding)

from tensorflow.keras.utils import plot_model 
plot_model(model,
           show_shapes=True, show_dtype=True, 
           show_layer_names=True,expand_nested=False)

这里是从输入到层的计算conv2_block1_0_relu被计算一次。接下来,如果我们想将此函数式 API 转换为子类化 API,我们必须从基本模型的输入到层构建一个模型conv2_block1_0_relu第一的。喜欢

class ModelWithMidLayer(tf.keras.Model):
    def __init__(self, dim=(32, 32, 3)):
        super().__init__()
        self.dim = dim
        self.base = DenseNet121(input_shape=self.dim, weights=None)
        
        # building sub-model from self.base which gives 
        # desired output feature maps: ie. conv2_block1_0_relu
        self.mid_layer = tf.keras.Model(self.base.inputs,
                                        self.base.get_layer("conv2_block1_0_relu").output)
        
        self.flat = Flatten()
        self.dense = Dense(1000)
        self.add = Add()
        self.cb = ConvBlock()
    
    def call(self, x):
        # forward with base model
        bx = self.base(x)

        # forward with mid layer
        mx = self.mid_layer(x)

        # make same shape or do whatever
        mx = self.dense(self.flat(mx))
        
        # combine
        out = self.add([bx, mx])
        return out
    
    def build_graph(self):
        x = tf.keras.layers.Input(shape=(self.dim))
        return tf.keras.Model(inputs=[x], outputs=self.call(x))

mwml = ModelWithMidLayer()
plot_model(mwml.build_graph(),
           show_shapes=True, show_dtype=True, 
           show_layer_names=True,expand_nested=False)

Here model_1实际上是一个子模型DenseNet,这可能导致整个模型(ModelWithMidLayer) 计算相同的操作两次。如果这一观察是正确的,那么这就会引起我们的担忧。


我认为它可能很复杂,但实际上非常简单。我们只需要构建一个具有所需输出层的模型__init__方法并在中正常使用call method.

import tensorflow as tf
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.layers import Add
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Flatten

class ConvBlock(tf.keras.layers.Layer):
    def __init__(self, kernel_num=32, kernel_size=(3,3), strides=(1,1), padding='same'):
        super(ConvBlock, self).__init__()
        # conv layer
        self.conv = tf.keras.layers.Conv2D(kernel_num, 
                        kernel_size=kernel_size, 
                        strides=strides, padding=padding)
        # batch norm layer
        self.bn = tf.keras.layers.BatchNormalization()

    def call(self, input_tensor, training=False):
        x = self.conv(input_tensor)
        x = self.bn(x, training=training)
        return tf.nn.relu(x)
class ModelWithMidLayer(tf.keras.Model):
    def __init__(self, dim=(32, 32, 3)):
        super().__init__()
        self.dim = dim
        self.base = DenseNet121(input_shape=self.dim, weights=None)
        
        # building sub-model from self.base which gives 
        # desired output feature maps: ie. conv2_block1_0_relu
        self.mid_layer = tf.keras.Model(
            inputs=[self.base.inputs],
            outputs=[
                     self.base.get_layer("conv2_block1_0_relu").output,
                     self.base.output])
        self.flat = Flatten()
        self.dense = Dense(1000)
        self.add = Add()
        self.cb = ConvBlock()
    
    def call(self, x):
        # forward with base model
        bx = self.mid_layer(x)[1] # output self.base.output
        # forward with mid layer
        mx = self.mid_layer(x)[0] # output base.get_layer("conv2_block1_0_relu").output
        # make same shape or do whatever
        mx = self.dense(self.flat(mx))
        # combine
        out = self.add([bx, mx])
        return out
    
    def build_graph(self):
        x = tf.keras.layers.Input(shape=(self.dim))
        return tf.keras.Model(inputs=[x], outputs=self.call(x))
mwml = ModelWithMidLayer()
tf.keras.utils.plot_model(mwml.build_graph(),
                          show_shapes=True, show_dtype=True, 
                          show_layer_names=True,expand_nested=False)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

获取中间层(Functional API)的输出并在SubClassed API中使用 的相关文章

  • django_openid_auth TypeError openid.yadis.manager.YadisServiceManager 对象不是 JSON 可序列化

    I used django openid auth在我的项目上 一段时间以来它运行得很好 但今天 我测试了该应用程序并遇到了这个异常 Environment Request Method GET Request URL http local
  • python 中的代表

    我实现了这个简短的示例来尝试演示一个简单的委托模式 我的问题是 这看起来我已经理解了委托吗 class Handler def init self parent None self parent parent def Handle self
  • 如何正确地将 MIDI 刻度转换为毫秒?

    我正在尝试将 MIDI 刻度 增量时间转换为毫秒 并且已经找到了一些有用的资源 MIDI Delta 时间刻度到秒 http www lastrayofhope co uk 2009 12 23 midi delta time ticks
  • python 模拟第三方模块

    我正在尝试测试一些处理推文的类 我使用 Sixohsix twitter 来处理 Twitter API 我有一个类充当 Twitter 类的外观 我的想法是模拟实际的 Sixohsix 类 通过随机生成新推文或从数据库检索它们来模拟推文的
  • 将数据帧行转换为字典

    我有像下面的示例数据这样的数据帧 我正在尝试将数据帧中的一行转换为类似于下面所需输出的字典 但是当我使用 to dict 时 我得到了索引和列值 有谁知道如何将行转换为像所需输出那样的字典 任何提示都非常感激 Sample data pri
  • Django 模型在模板中不可迭代

    我试图迭代模型以获取列表中的第一个图像 但它给了我错误 即模型不可迭代 以下是我的模型和模板的代码 我只需要获取与单个产品相关的列表中的第一个图像 模型 py class Product models Model title models
  • Argparse nargs="+" 正在吃位置参数

    这是我的解析器配置的一小部分 parser add argument infile help The file to be imported type argparse FileType r default sys stdin parser
  • 如何在 pytest 中将单元测试和集成测试分开

    根据维基百科 https en wikipedia org wiki Unit testing Description和各种articles https techbeacon com devops 6 best practices inte
  • Pandas 中允许重复列

    我将一个大的 CSV 包含股票财务数据 文件分割成更小的块 CSV 文件的格式不同 像 Excel 数据透视表之类的东西 第一列的前几行包含一些标题 公司名称 ID 等在以下列中重复 因为一家公司有多个属性 而不是一家公司只有一栏 在前几行
  • 切片 Dataframe 时出现 KeyError

    我的代码如下所示 d pd read csv Collector Output csv df pd DataFrame data d dfa df copy dfa dfa rename columns OBJECTID Object ID
  • 使用鼻子获取设置中当前测试的名称

    我目前正在使用鼻子编写一些功能测试 我正在测试的库操作目录结构 为了获得可重现的结果 我存储了一个测试目录结构的模板 并在执行测试之前创建该模板的副本 我在测试中执行此操作 setup功能 这确保了我在测试开始时始终具有明确定义的状态 现在
  • 在 Pandas 中使用正则表达式的多种模式

    我是Python编程的初学者 我正在探索正则表达式 我正在尝试从 描述 列中提取一个单词 数据库名称 我无法给出多个正则表达式模式 请参阅下面的描述和代码 描述 Summary AD1 Low free DATA space in data
  • 创建嵌套字典单行

    您好 我有三个列表 我想使用一行创建一个三级嵌套字典 i e l1 a b l2 1 2 3 l3 d e 我想创建以下嵌套字典 nd a 1 d 0 e 0 2 d 0 e 0 3 d 0 e 0 b a 1 d 0 e 0 2 d 0
  • 使用队列从多个输入文件中统一采样

    我的数据集中的每个类都有一个序列化文件 我想使用队列来加载每个文件 然后将它们放入 RandomShuffleQueue 中 这样我就可以从每个类中获得随机的示例组合 我认为这段代码会起作用 在此示例中 每个文件有 10 个示例 filen
  • Tkinter - 浮动窗口 - 调整大小

    灵感来自this https stackoverflow com a 22424245 13629335问题 我想为我的根窗口编写自己的调整大小函数 但我刚刚注意到我的代码显示了一些性能问题 如果你快速调整它的大小 你会发现窗口没有像我希望
  • Ubuntu 上的 Python 2.7

    我是 Python 新手 正在 Linux 机器 Ubuntu 10 10 上工作 它正在运行 python 2 6 但我想运行 2 7 因为它有我想使用的功能 有人敦促我不要安装 2 7 并将其设置为我的默认 python 我的问题是 如
  • 在Python中按属性获取对象列表中的索引

    我有具有属性 id 的对象列表 我想找到具有特定 id 的对象的索引 我写了这样的东西 index 1 for i in range len my list if my list i id specific id index i break
  • 如何读取Python字节码?

    我很难理解 Python 的字节码及其dis module import dis def func x 1 dis dis func 上述代码在解释器中输入时会产生以下输出 0 LOAD CONST 1 1 3 STORE FAST 0 x
  • 检查字典键是否有空值

    我有以下字典 dict1 city name yass region zipcode phone address tehsil planet mars 我正在尝试创建一个基于 dict1 的新字典 但是 它不会包含带有空字符串的键 它不会包
  • 从 Twitter API 2.0 获取 user.fields 时出现问题

    我想从 Twitter API 2 0 端点加载推文 并尝试获取标准字段 作者 文本 和一些扩展字段 尤其是 用户 字段 端点和参数的定义工作没有错误 在生成的 json 中 我只找到标准字段 但没有找到所需的 user fields 用户

随机推荐

  • 如何在 Verilog 中综合 While 循环?

    我尝试设计一个 Booth 乘法器 它在所有编译器中运行良好 包括 Modelsim Verilogger Extreme Aldec Active Hdl 和 Xilinx Isim 我知道模拟和综合是两个不同的过程 而且只有少数Veri
  • 使用 SELECT 执行 INSERT 插入多条记录

    在下图中 DodgyOldTable 和 MainTable 之间存在 1 1 关系 表 Option 包含 OptionDesc 字段中带有 OptionVal1 OptionVal2 和 OptionVal3 的记录 我需要使用 Dod
  • PHP 中未终止的实体引用

    这是我的代码
  • 什么是名称查找机制?

    我想知道C 名称查找机制是什么 名称查找是识别名称含义的过程 名称查找有两个目的 消除代码解析的歧义 确定代码的确切含 义 例如 如果你有这个代码 T a 这取决于是否T是否是一个类型 如果是一个类型 它将是一个声明a 如果它不是类型 则将
  • 我必须在 Next.js 项目中使用express吗?

    我正在制作一个网站Next js Next js提供SSR and dynamic routing 我必须使用express 如果是这样 为什么我必须使用它 具有什么样的特点express有用但未提供的Next js I think nex
  • 提供满足esm、commonjs和bundlers的模块、主要和浏览器字段

    我有许多已发布的 npm 包 我已将它们升级为提供 commonjs 和 esm 构建 有些包可能同时适用于节点和浏览器 所有使用 webpack 或 rollup 编译的包 所有内容都用打字稿编写并转换为dist目录 我创建了一个comm
  • iOS 上的 html5 录音

    我正在尝试访问 iOS 上的麦克风以捕获用户输入
  • Azure SQL 中所有用户的列表

    如何列出可以连接到我的 sql server 数据库的所有用户 现在可以找到任何 sql 命令 我尝试了互联网上的一些链接 但没有一个有效 我尝试过的一些命令 SELECT FROM sys sql logins SELECT FROM s
  • 我应该如何配置 grunt-usemin 来使用相对路径

    我有一个由 yeoman generator 支持的 grunt 项目 我是基于generator webapp https github com yeoman generator webapp 如果有帮助 您可以在GitHub https
  • MVVM 和 StructureMap 的使用

    我的 MVVM 应用程序中有大量父级详细信息 ViewModel 像这样的事情 SchoolsViewModel SchoolViewModel LessonViewModel PupilsViewModel PupilViewModel
  • Bitset 作为函数的返回值

    我想要一个接口 其函数返回一个位集 class IMyInterface public virtual std bitset lt 100 gt GetBits 0 问题是我不想强制大小bitset 所以我想我必须使用boost dynam
  • Ruby 2.0 字节码导出/导入

    我一直在读关于红宝石 2 0 新功能 http www rubyinside com ruby 2 0 implementation work begins what is ruby 2 0 and whats new 5515 html
  • 使用 JavaScript 创建 Base64 编码图像

    由于图像是数据 我们可以将代码编写为 img src alt Red dot 现在我的观点是 我们可以使用 javascript 创建 base64 数据吗 有什么框架吗 我的实际要求是我有一个像 Cow 这样的字符串 我希望它作为图像 注
  • Java 最终抽象类

    我有一个非常简单的问题 我想要一个 Java 类 它提供一个公共静态方法 该方法可以执行某些操作 这只是为了封装目的 将所有重要的内容都放在一个单独的类中 这个类既不应该被实例化 也不应该被扩展 这让我写道 final abstract c
  • 打开软键盘时,DialogFragment 始终会调整大小

    我在全屏显示的自定义 DialogFragment 方面遇到一些问题 该对话框包含可滚动的内容并具有自动完成文本视图 最初 对话框在顶部显示有边距 以编程方式设置为布局内容顶部的透明视图 一旦 autocompletetextview 获得
  • Android-如何区分 Galaxy S-3 和 Galaxy S-4 布局?

    我在区分三星 Galaxy s4 和三星 Galaxy s3 的布局文件夹时遇到问题 我尝试过layout sw360dp layout sw360dp xxhdpi layout sw360dp xhdpi等 一直以来 galaxy s4
  • 是否可以在 iCal.net 上使用 UTC 偏移量代替时区名称?

    我的应用程序将 UTC 偏移量存储在用户配置文件上 例如 03 00 并且正如 iCal net Wiki 中的此页面提到的那样 我似乎只能使用时区来分配给事件 https github com rianjs ical net wiki W
  • 在 Chrome 中单击并拖动光标

    我正在开发一个网络应用程序 我需要覆盖一些默认光标 在 Chrome 中 当我单击然后拖动它时 它总是将光标更改为文本选择 我似乎找不到任何方法来覆盖它 我正在使用jquery 通常的 document css cursor default
  • fifo - 循环读取

    我想用os mkfifo http docs python org 2 library os html os mkfifo用于程序之间的简单通信 我在循环读取 fifo 时遇到问题 考虑这个玩具示例 其中我有一个读取器和一个写入器使用 fi
  • 获取中间层(Functional API)的输出并在SubClassed API中使用

    In the 喀拉斯文档 https keras io getting started faq 它说如果我们想选择中间层模型的输出 顺序和功能 我们需要做的如下 model create the original model layer n