如何将具有像 ResNet 这样的非序列架构的 Keras 模型拆分为子模型?

2024-02-01

我的模型是 resnet-152,我想将其切成两个子模型,问题是第二个子模型,我不知道如何构建从中间层到输出的模型

我尝试了这段代码这个回应 https://stackoverflow.com/questions/52800025/keras-give-input-to-intermediate-layer-and-get-final-output/56140169#56140169它对我不起作用,这是我的代码:

def getLayerIndexByName(model, layername):
    for idx, layer in enumerate(model.layers):
        if layer.name == layername:
            return idx

idx = getLayerIndexByName(resnet, 'res3a_branch2a')

input_shape = resnet.layers[idx].get_input_shape_at(0) # which is here in my case (None, 55, 55, 256)

layer_input = Input(shape=input_shape[1:]) # as keras will add the batch shape

# create the new nodes for each layer in the path
x = layer_input
for layer in resnet.layers[idx:]:
    x = layer(x)

# create the model
new_model = Model(layer_input, x)

我收到此错误:

ValueError: Input 0 is incompatible with layer res3a_branch1: expected axis -1 of input shape to have value 256 but got shape (None, 28, 28, 512).

我也尝试过这个功能:

def split(model, start, end):
    confs = model.get_config()
    kept_layers = set()
    for i, l in enumerate(confs['layers']):
        if i == 0:
            confs['layers'][0]['config']['batch_input_shape'] = model.layers[start].input_shape
            if i != start:
                confs['layers'][0]['name'] += str(random.randint(0, 100000000)) # rename the input layer to avoid conflicts on merge
                confs['layers'][0]['config']['name'] = confs['layers'][0]['name']
        elif i < start or i > end:
            continue
        kept_layers.add(l['name'])
    # filter layers
    layers = [l for l in confs['layers'] if l['name'] in kept_layers]
    layers[1]['inbound_nodes'][0][0][0] = layers[0]['name']
    # set conf
    confs['layers'] = layers
    confs['input_layers'][0][0] = layers[0]['name']
    confs['output_layers'][0][0] = layers[-1]['name']
    # create new model
    submodel = Model.from_config(confs)
    for l in submodel.layers:
        orig_l = model.get_layer(l.name)
        if orig_l is not None:
            l.set_weights(orig_l.get_weights())
    return submodel

我收到此错误:

ValueError: Unknown layer: Scale

因为我的 resnet152 包含一个 Scale 层。

这是一个工作版本:

import resnet   # pip install resnet
from keras.models import Model
from keras.layers import Input

def getLayerIndexByName(model, layername):
    for idx, layer in enumerate(model.layers):
        if layer.name == layername:
            return idx


resnet = resnet.ResNet152(weights='imagenet')

idx = getLayerIndexByName(resnet, 'res3a_branch2a')

model1 = Model(inputs=resnet.input, outputs=resnet.get_layer('res3a_branch2a').output)

input_shape = resnet.layers[idx].get_input_shape_at(0) # get the input shape of desired layer
print(input_shape[1:])
layer_input = Input(shape=input_shape[1:]) # a new input tensor to be able to feed the desired layer

# create the new nodes for each layer in the path
x = layer_input
for layer in resnet.layers[idx:]:
    x = layer(x)

# create the model
model2 = Model(layer_input, x)

model2.summary()

这是错误:

ValueError: Input 0 is incompatible with layer res3a_branch1: expected axis -1 of input shape to have value 256 but got shape (None, 28, 28, 512)

正如我在评论部分提到的,由于 ResNet 模型没有线性架构(即它具有跳过连接,并且一个层可能连接到多个层),因此您不能简单地逐层浏览模型的层一个循环,并在循环中前一层的输出上应用一个层(即与具有线性架构的模型不同,这个方法有效 https://stackoverflow.com/a/52814386/2099607).

因此,您需要找到各层的连通性并遍历该连通性图,以便能够构建原始模型的子模型。目前,我想到了这个解决方案:

  1. 指定子模型的最后一层。
  2. 从该层开始,找到与其连接的所有层。
  3. 获取这些连接层的输出。
  4. 将最后一层应用于收集的输出。

显然,步骤#3意味着递归:为了获得连接层(即X)的输出,我们首先需要找到它们的连接层(即Y),获取它们的输出(即Y的输出),然后将它们应用到这些输出上(即在 Y 的输出上应用 X)。此外,要找到连接层,您需要了解一些 Keras 的内部结构,这已在这个答案 https://stackoverflow.com/a/53944525/2099607。所以我们提出了这个解决方案:

from keras.applications.resnet50 import ResNet50
from keras import models
from keras import layers

resnet = ResNet50()

# this is the split point, i.e. the starting layer in our sub-model
starting_layer_name = 'activation_46'

# create a new input layer for our sub-model we want to construct
new_input = layers.Input(batch_shape=resnet.get_layer(starting_layer_name).get_input_shape_at(0))

layer_outputs = {}
def get_output_of_layer(layer):
    # if we have already applied this layer on its input(s) tensors,
    # just return its already computed output
    if layer.name in layer_outputs:
        return layer_outputs[layer.name]

    # if this is the starting layer, then apply it on the input tensor
    if layer.name == starting_layer_name:
        out = layer(new_input)
        layer_outputs[layer.name] = out
        return out

    # find all the connected layers which this layer
    # consumes their output
    prev_layers = []
    for node in layer._inbound_nodes:
        prev_layers.extend(node.inbound_layers)

    # get the output of connected layers
    pl_outs = []
    for pl in prev_layers:
        pl_outs.extend([get_output_of_layer(pl)])

    # apply this layer on the collected outputs
    out = layer(pl_outs[0] if len(pl_outs) == 1 else pl_outs)
    layer_outputs[layer.name] = out
    return out

# note that we start from the last layer of our desired sub-model.
# this layer could be any layer of the original model as long as it is
# reachable from the starting layer
new_output = get_output_of_layer(resnet.layers[-1])

# create the sub-model
model = models.Model(new_input, new_output)

重要笔记:

  1. 该解决方案假设原始模型中的每个层仅使用一次,即它不适用于暹罗网络,其中一个层可以共享,因此可能在不同的输入张量上应用多次。

  2. 如果您想将模型正确分割为多个子模型,那么仅使用这些层作为分割点是有意义的(例如由starting_layer_name在上面的代码中),它们不在分支中(例如,在 ResNet 中,合并层之后的激活层是一个不错的选择,但是res3a_branch2a您选择的不是一个好的选择,因为它位于分支中)。为了更好地了解模型的原始架构,您始终可以使用以下命令绘制其图表plot_model()实用功能:

    from keras.applications.resnet50 import ResNet50
    from keras.utils import plot_model
    
    resnet = ResNet50()
    plot_model(model, to_file='resnet_model.png')
    
  3. 由于在构建子模型后会创建新节点,因此不要尝试构建另一个子模型有重叠的(即,如果它没有重叠,那就可以了!)与之前的子模型在上面代码的同一运行中;否则,您可能会遇到错误。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何将具有像 ResNet 这样的非序列架构的 Keras 模型拆分为子模型? 的相关文章

  • 以 str.format 切片字符串

    我想实现以下目标str format x y 1234 5678 print str x 2 str y 2 我能够做到这一点的唯一方法是 print 0 1 format str x 2 str y 2 现在 这是一个例子 我真正拥有的是
  • 如何使用playsound模块停止音频?

    如何在Python代码中通过playaudio模块停止音频播放 我播放过音乐 但我无法停止音乐 我怎样才能阻止它 playsound playsound name of file 您可以使用多处理模块将声音作为后台进程播放 然后随时终止它
  • 当语料库有100亿个独特的DNA序列时,如何使用BK树实现快速模糊搜索引擎?

    我正在尝试使用BK tree https news ycombinator com item id 14022424python 中的数据结构 用于存储约 100 亿个条目的语料库 1e10 以实现快速模糊搜索引擎 一旦我添加超过 1000
  • 如何将人物传奇带到前台?

    我有一系列子图 其中每个子图都有一个图例 我想在每个子图之外与相邻子图重叠 问题在于图例位于其自己的图的 顶部 但位于相邻图的下方 Legend 不将 zorder 作为参数 所以我不知道如何解决这个问题 这是我使用过的代码 import
  • 如何在 Python 2.4 CSV 阅读器中禁用引用?

    我正在编写一个 Python 实用程序 需要解析一个我无法控制的大型且定期更新的 CSV 文件 该实用程序必须在仅提供 Python 2 4 的服务器上运行 CSV 文件根本不引用字段值 但Python 2 4版本的csv库 http ww
  • Huggingface 变形金刚模块未被 anaconda 识别

    我正在使用 Anaconda python 3 7 Windows 10 我尝试通过安装变压器https huggingface co transformers https huggingface co transformers 在我的环境
  • 将文件标记为从 Python 中删除?

    在我的一个脚本中 我需要删除当时可能正在使用的文件 我知道我无法删除正在使用的文件 直到它不再使用为止 但我也知道我可以将该文件标记为由操作系统 Windows XP 删除 我将如何在 Python 中做到这一点 以及另一个不依赖于 pyw
  • Python,将字典存储在数据库中

    在数据库中存储和检索 python 字典的最佳方法是什么 如果您对使用传统 SQL 数据库 例如 MySQL 不是特别感兴趣 您可以研究非结构化文档数据库 其中文档自然映射到 python 字典 例如MongoDB http www mon
  • self.__dict__.update(**kwargs) 的风格是好是坏?

    在 Python 中 假设我有一些类 Circle 它继承自 Shape Shape 需要 x 和 y 坐标 此外 Circle 需要半径 我希望能够通过执行类似的操作来初始化 Circle c Circle x 1 y 5 r 3 Cir
  • Pygooglevoice登录错误

    另一个人问了这个问题 但没有回复 所以我再问一遍 我正在尝试使用 pygooglevoice API 但是当我运行 SMS py 示例脚本时 它给了我一个登录错误 我已经安装了 Enthought python 我想也许我还需要安装其他东西
  • 自定义 Keras 损失函数中的 conv2d

    我正在尝试基于两个图像的拉普拉斯算子在带有 TF 后端的 Keras 中实现自定义损失函数 def blur loss y true y pred weighting of blur loss alpha 1 mae losses mean
  • 从网站上抓取数字和详细信息的数据

    我想从网站上抓取联系电话以及快递服务的相应详细信息 我无法从所有快递服务中获取联系电话和其他详细信息 例如姓名地址和评级 我分析的数据位于脚本标签中 请提出修复此问题的建议 import requests import pandas as
  • t /= d 是什么意思? Python 和错误

    t current time b begInnIng value c change In value d duration def easeOutQuad swing function x t b c d alert jQuery easi
  • pandas-更改重采样时间序列的开始和结束日期

    我有一个时间序列 我将其重新采样到这个数据框中df 我的数据是从6月6日到6月28日 它希望将数据从6月1日延长到6月30日 计数列仅在较长时间内具有 0 值 而我的实际值是从 6 日到 28 日 Out 123 count Timesta
  • CryptoJS 和 Pycrypto 一起工作

    我正在使用 CryptoJS v 2 3 加密 Web 应用程序中的字符串 并且需要在服务器上使用 Python 对其进行解密 因此我使用 PyCrypto 我觉得我错过了一些东西 因为我无法让它工作 这是JS Crypto AES enc
  • 枚举上的 random.choice

    我想用random choice on an Enum I tried class Foo Enum a 0 b 1 c 2 bar random choice Foo 但是这段代码失败了KeyError 我怎样才能随机选择一个成员Enum
  • 执行许多插入重复键更新错误:未使用所有参数

    所以我一直在尝试使用 python 2 7 15 使用 mysql connector 执行此查询 但由于某种原因 它似乎不起作用并且总是返回错误 并非所有参数都被使用 表更新有一个主键 即 ID 这是我尝试运行此 SQL 的查询 sql
  • PyMC3 和 Theano - 导入 pymc3 后,有效的 Theano 代码停止工作

    一些简单的 theano 代码可以完美运行 当我导入 pymc3 时停止工作 这里有一些片段可以重现错误 Initial Theano Code this works import theano tensor as tsr x tsr ds
  • 混合两个列表的Pythonic方法[重复]

    这个问题在这里已经有答案了 我有两个长度为 n 和 n 1 的列表 a 1 a 2 a n b 1 b 2 b n 1 我想要一个函数作为结果给出一个列表 其中包含两个中的替代元素 即 b 1 a 1 b n a n b n 1 以下方法有
  • 将 .parquet 编码为 io.Bytes

    目标 将 Parquet 文件上传到 MinIO 这需要将文件转换为字节 我已经能够做到这一点了 csv json and txt bytes data to csv encode utf 8 bytes json dumps self d

随机推荐