带有 Tensorflow 后端的 Keras 的 K.function 方法是否适用于网络层?

2024-04-05

我最近开始使用 Keras 构建神经网络。我构建了一个简单的 CNN 来对 MNIST 数据集进行分类。在学习我使用的模型之前K.set_image_dim_ordering('th')为了绘制卷积层权重。现在我正在尝试用以下方法可视化卷积层输出K.function方法,但我不断收到错误。

这是我现在想做的:

input_image = X_train[2:3,:,:,:]

output_layer = model.layers[1].output
input_layer = model.layers[0].input

output_fn = K.function(input_layer, output_layer)

output_image = output_fn.predict(input_image)
print(output_image.shape)

output_image = np.rollaxis(np.rollaxis(output_image, 3, 1), 3, 1)
print(output_image.shape)

fig = plt.figure()
for i in range(32):
    ax = fig.add_subplot(4,8,i+1)
    im = ax.imshow(output_image[0,:,:,i], cmap="Greys")
    plt.xticks(np.array([]))
    plt.yticks(np.array([]))
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([1, 0.1, 0.05 ,0.8])
fig.colorbar(im, cax = cbar_ax)
plt.tight_layout()

plt.show()

这就是我得到的:

  File "/home/kinshiryuu/anaconda3/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py", line 1621, in function
return Function(inputs, outputs, updates=updates)

  File "/home/kinshiryuu/anaconda3/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py", line 1569, in __init__
raise TypeError('`inputs` to a TensorFlow backend function '

TypeError: `inputs` to a TensorFlow backend function should be a list or tuple.

您应该进行以下更改:

output_fn = K.function([input_layer], [output_layer])
output_image = output_fn([input_image])

K.function将输入和输出张量作为列表,以便您可以创建从多个输入到多个输出的函数。在您的情况下,一个输入一个输出..但您仍然需要将它们作为列表传递。

Next K.function返回一个张量函数,而不是您可以使用的模型对象predict()。正确的使用方法就是作为函数调用

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

带有 Tensorflow 后端的 Keras 的 K.function 方法是否适用于网络层? 的相关文章

随机推荐