Tensorflow Hub - 获取模型的输入形状和问题域?

2024-05-13

我正在使用最新版本的tensorflow hub,想知道如何获取有关模型的预期输入形状以及模型属于什么类型的集合的信息。 例如,有没有办法以这种方式在 Python 中加载模型后获取有关预期图像形状的信息?

model = hub.load("https://tfhub.dev/tensorflow/faster_rcnn/inception_resnet_v2_640x640/1")

还是这样?

model = hub.KerasLayer("https://tfhub.dev/tensorflow/faster_rcnn/inception_resnet_v2_640x640/1")

似乎两种情况下的模型对象都不知道预期的形状是什么 - 无论是图像高度/宽度还是批量大小。另一方面,可以通过以下方式找到此信息load_module_spec对于较旧的 TF 型号...

还有一个问题:是否有一种方法可以以编程方式获取模型属于哪个“问题域”的信息?可以查一下https://tfhub.dev/ https://tfhub.dev/,但是如果需要从模型对象本身或通过访问该信息怎么办?tensorflow_hub功能?

Thanks!


您可以通过访问模型的第一层并访问该层的 input_shape 属性来获取模型期望的输入形状

layers = model.layers
first_layer = layers[0] # usually the first layer is the input layer
print(first_layer.input_shape)

output:

[(None, 100, 100, 3)] # sample output

None -> 这指定了批量大小的大小,推断批量大小可以是您指定的任何值

(100, 100, 3) -> 高度、宽度和通道可能会有所不同,并且您给出的输入数据应该严格相同。

通过编程找到训练模型的域有点棘手,您可以使用 tensorflow.keras.util.plot_model 绘制模型的图,并可以从模型的架构推断域。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow Hub - 获取模型的输入形状和问题域? 的相关文章

随机推荐