I trained a CNN and saved it as follows:
model = Sequential()
model.add(Flatten(input_shape=train_data.shape[1:]))
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop',
              loss='binary_crossentropy', metrics=['accuracy'])
model.fit(train_data, train_labels,
          epochs=epochs,
          batch_size=batch_size,
          validation_data=(validation_data, validation_labels))
model.save('full_model.h5')
I am now trying to load the model in another Python script with:
model = tf.keras.models.load_model('full_model.h5')
and I get the following error:
Traceback (most recent call last):
  File "/media/spt/Data/tensorflow_server/get_model.py", line 12, in <module>
    model = tf.keras.models.load_model('full_model.h5')
  File "/home/spt/.conda/envs/dev_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/saving.py", line 229, in load_model
    model = model_from_config(model_config, custom_objects=custom_objects)
  File "/home/spt/.conda/envs/dev_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/saving.py", line 306, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File "/home/spt/.conda/envs/dev_env/lib/python3.6/site-packages/tensorflow/python/keras/layers/serialization.py", line 64, in deserialize
    printable_module_name='layer')
  File "/home/spt/.conda/envs/dev_env/lib/python3.6/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 173, in deserialize_keras_object
    list(custom_objects.items())))
  File "/home/spt/.conda/envs/dev_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/sequential.py", line 286, in from_config
    layer = layer_module.deserialize(conf, custom_objects=custom_objects)
  File "/home/spt/.conda/envs/dev_env/lib/python3.6/site-packages/tensorflow/python/keras/layers/serialization.py", line 64, in deserialize
    printable_module_name='layer')
  File "/home/spt/.conda/envs/dev_env/lib/python3.6/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 193, in deserialize_keras_object
    function_name)
ValueError: Unknown layer:name
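I have come across several sites describing the same or a similar problem, for example Stack Overflow https://stackoverflow.com/questions/53180589/keras-valueerror-unknown-layername-when-trying-to-load-model-to-another-platf and GitHub https://github.com/keras-team/keras/issues/11617. Usually the cause is an outdated Keras version. In my case, however, all Keras-related packages are up to date (output of conda list for all Keras-related packages):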
keras-applications 1.0.6 py36_0
keras-base 2.2.4 py36_0
keras-gpu 2.2.4 0
keras-preprocessing 1.0.5 py36_0
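One possible mechanism behind this exact message can be sketched in plain Python, with no Keras dependency. The assumption (not confirmed by the traceback alone) is that standalone Keras 2.2.x saves a Sequential model's config as a dict with `name` and `layers` keys, while the tf.keras loader in the traceback (`sequential.py`, `from_config`) iterates that config as if it were a plain list of layer configs. Iterating a dict yields its keys, so the first "layer" looked up would be the string `name`. The config, `KNOWN_LAYERS`, and both loader functions below are hypothetical stand-ins, not the real Keras code.

```python
import json

# Hypothetical, trimmed-down model config in the shape standalone Keras 2.2.x is
# ASSUMED to write for a Sequential model: a dict holding "name" and "layers".
SAVED_MODEL_CONFIG = json.loads("""
{"class_name": "Sequential",
 "config": {"name": "sequential_1",
            "layers": [{"class_name": "Flatten", "config": {}},
                       {"class_name": "Dense", "config": {"units": 256}},
                       {"class_name": "Dropout", "config": {"rate": 0.5}},
                       {"class_name": "Dense", "config": {"units": 1}}]}}
""")

# Hypothetical stand-in for the registry of layer classes a loader knows about.
KNOWN_LAYERS = {"Flatten", "Dense", "Dropout"}


def deserialize_layer(conf):
    """Resolve one layer config; a bare string is treated as a layer class name."""
    name = conf["class_name"] if isinstance(conf, dict) else conf
    if name not in KNOWN_LAYERS:
        raise ValueError("Unknown layer:" + name)
    return name


def from_config_list_only(config):
    """Older-style loader: assumes the Sequential config is a plain list of layers."""
    # Iterating a dict yields its keys, so the first "layer" seen is the string "name".
    return [deserialize_layer(conf) for conf in config]


def from_config_dict_aware(config):
    """Newer-style loader: accepts both the list form and the {"name", "layers"} form."""
    layer_configs = config["layers"] if isinstance(config, dict) else config
    return [deserialize_layer(conf) for conf in layer_configs]


if __name__ == "__main__":
    inner = SAVED_MODEL_CONFIG["config"]
    try:
        from_config_list_only(inner)
    except ValueError as err:
        print(err)  # Unknown layer:name
    print(from_config_dict_aware(inner))  # ['Flatten', 'Dense', 'Dropout', 'Dense']
```

If this is what happens, the relevant mismatch would be between the package that wrote the file (standalone `keras`) and the `tf.keras` that reads it, rather than the standalone Keras version alone.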
Can anyone suggest how to fix or work around this?