目前,我们还不能使用tf.keras.models.clone_model
对于子类化模型 API,而我们可以对于顺序 API 和函数式 API。从doc https://www.tensorflow.org/api_docs/python/tf/keras/models/clone_model,
model Instance of Model (could be a functional model or a Sequential model).
这是满足您需求的解决方法。如果我们需要复制经过训练的模型,我们可以在其中获得一些优化的参数,这是有意义的。因此,主要任务是我们需要通过复制现有模型来创建新模型。目前这种情况最方便的方法是get
训练有素的体重和set
到新创建的模型实例。首先构建一个模型,对其进行训练,然后获取权重矩阵并将其设置到新模型。
import tensorflow as tf
import numpy as np
class ModelSubClassing(tf.keras.Model):
def __init__(self, num_classes):
super(ModelSubClassing, self).__init__()
self.conv1 = tf.keras.layers.Conv2D(32, 3, strides=2, activation="relu")
self.gap = tf.keras.layers.GlobalAveragePooling2D()
self.dense = tf.keras.layers.Dense(num_classes)
def call(self, input_tensor, training=False):
# forward pass: block 1
x = self.conv1(input_tensor)
x = self.gap(x)
return self.dense(x)
def build_graph(self, raw_shape):
x = tf.keras.layers.Input(shape=raw_shape)
return tf.keras.Model(inputs=[x], outputs=self.call(x))
# compile
sub_classing_model = ModelSubClassing(10)
sub_classing_model.compile(
loss = tf.keras.losses.CategoricalCrossentropy(),
metrics = tf.keras.metrics.CategoricalAccuracy(),
optimizer = tf.keras.optimizers.Adam())
# plot for debug
tf.keras.utils.plot_model(
sub_classing_model.build_graph(x_train.shape[1:]),
show_shapes=False,
show_dtype=False,
show_layer_names=True,
expand_nested=False,
dpi=96,
)
DataSet
(x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()
# train set / data
x_train = np.expand_dims(x_train, axis=-1)
x_train = x_train.astype('float32') / 255
# train set / target
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
# fit
sub_classing_model.fit(x_train, y_train, batch_size=128, epochs=1)
# 469/469 [==============================] - 2s 2ms/step - loss: 8.2821
新模型/副本
对于子类模型,我们必须启动类对象。
sub_classing_model_copy = ModelSubClassing(10)
sub_classing_model_copy.build((x_train.shape))
sub_classing_model_copy.set_weights(sub_classing_model.get_weights()) # <- get and set wg
# plot for debug ; same as original plot
# but know, layer name is no longer same
# i.e. if, old: conv2d_40 , new/copy: conv2d_41
tf.keras.utils.plot_model(
sub_classing_model_copy.build_graph(x_train.shape[1:]),
show_shapes=False,
show_dtype=False,
show_layer_names=True,
expand_nested=False,
dpi=96,
)