首先,是的,您应该使用新的 SavedModel 格式,因为它是 TF 团队今后将支持的格式,并且也可以与 Keras 配合使用。您可以向模型添加一个额外的端点,该端点返回一个带有 XML 数据字符串的常量张量(如您所提到的)。
这很好,因为它是封闭的——底层的保存模型格式并不重要,因为您的元数据保存在计算图中本身。
请看这个问题的答案:使用自定义签名定义保存 TF2 keras 模型 https://stackoverflow.com/questions/56659949/saving-a-tf2-keras-model-with-custom-signature-defs。这个答案并不能让你 100% 理解 Keras,因为它不能与 tf.keras.models.load 函数很好地互操作,因为它们将其包装在一个tf.Module
。幸运的是,使用tf.keras.Model
如果添加 tf.function 装饰器,在 TF2 中也能正常工作:
class MyModel(tf.keras.Model):
def __init__(self, metadata, **kwargs):
super(MyModel, self).__init__(**kwargs)
self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
self.metadata = tf.constant(metadata)
def call(self, inputs):
x = self.dense1(inputs)
return self.dense2(x)
@tf.function(input_signature=[])
def get_metadata(self):
return self.metadata
model = MyModel('metadata_test')
input_arr = tf.random.uniform((5, 5, 1)) # This call is needed so Keras knows its input shape. You could define manually too
outputs = model(input_arr)
然后您可以保存并加载模型,如下所示:
tf.keras.models.save_model(model, 'test_model_keras')
model_loaded = tf.keras.models.load_model('test_model_keras')
最后使用model_loaded.get_metadata()
检索您的常量元数据张量。