Tensorflow 2 Hub:如何获取中间层的输出?

2023-12-31

我正在尝试实施以下网络Fots https://arxiv.org/pdf/1801.01671.pdf使用新的tensorflow 2进行文本检测。作者使用resnet作为其网络的骨干,所以我的第一个想法是使用tensoflow hub resnet来加载预训练的网络。但问题是我找不到打印模块摘要的方法,该模块是从 tfhub 加载的?

有什么方法可以从 tf-hub 中查看加载模块的层数吗? 谢谢


Update

不幸的是,resnet 不适用于 tf2-hub,所以我决定使用内置的 keras 实现 resent,至少在有 hub 实现之前是这样。它的。

以下是我如何使用 tf2.keras.applications 获取 resnet 的中间层:

import numpy as np
import tensorflow as tf
from tensorflow import keras

layers_out = ["activation_9", "activation_21", "activation_39", "activation_48"]

imgs = np.random.randn(2, 640, 640, 3).astype(np.float32)
model = keras.applications.resnet50.ResNet50(input_shape=(640, 640, 3), include_top=False)
intermid_outputs= [model.get_layer(layer_name).output for layer_name in layers_out]
shared_conds = keras.Model(inputs=model.input, outputs=intermid_outputs)
Y = conv_shared(imgs)
shapes = [y.shape for y in Y]
print(shapes)

您可以执行以下操作来检查中间输出:

resnet = hub.Module("https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/3")
outputs = resnet(np.random.rand(1,224,224,3), signature="image_feature_vector", as_dict=True)
for intermediate_output in outputs.keys():
    print(intermediate_output)

然后,如果您想将 hub 模块的中间层链接到图表的其余部分,您可以执行以下操作:

resnet = hub.Module("https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/3")
features = resnet(images, signature="image_feature_vector", as_dict=True)["resnet_v2_50/block4"]
flatten = tf.reshape(features, (-1, features.shape[3]))

假设我们要从 ResNet 的最后一个块中提取特征。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow 2 Hub:如何获取中间层的输出? 的相关文章

随机推荐