你可以写一个自定义的Callback
并在每次纪元结束时使用它。我展示它是为了打印权重,但您可以将它用作自定义损失的一部分。
class CustomCallback(keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
rand_int = tf.random.uniform((), 0, 2, dtype=tf.int32)
print(rand_int)
model.fit(X, y epochs = 10, batch_size = 20, validation_split=0.1, callbacks=[CustomCallback()])
更多细节here https://www.tensorflow.org/guide/keras/custom_callback.
例如,这是一个用于打印的虚拟代码weights and biases
of layer[1]
每个纪元之后。您可以按照您喜欢的方式设置该功能。
from tensorflow.keras import layers, Model, callbacks
class CustomCallback(callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
print(' ')
print(' ')
print(model.layers[1].get_weights())
X, y = np.random.random((10,5)), np.random.random((10,))
inp = layers.Input((5,))
x = layers.Dense(3)(inp)
out = layers.Dense(1)(x)
model = Model(inp, out)
model.compile(loss='MAE',metrics=['accuracy'])
model.fit(X,y,callbacks=[CustomCallback()], epochs=3)
Epoch 1/3
1/1 [==============================] - ETA: 0s - loss: 0.2346 - accuracy: 0.0000e+00
[array([[ 0.16518219, -0.44628695, -0.07702655],
[-0.1993848 , 0.03855793, -0.62964785],
[ 0.5592851 , -0.28281152, -0.23358124],
[ 0.05242977, 0.4023881 , -0.19522922],
[ 0.07936202, -0.40436065, 0.10003945]], dtype=float32), array([ 0.01530731, -0.01565045, -0.01581042], dtype=float32)]
1/1 [==============================] - 0s 2ms/step - loss: 0.2346 - accuracy: 0.0000e+00
Epoch 2/3
1/1 [==============================] - ETA: 0s - loss: 0.2337 - accuracy: 0.0000e+00
[array([[ 0.16814367, -0.4492649 , -0.08000461],
[-0.19710523, 0.03622784, -0.6319782 ],
[ 0.55797213, -0.28144714, -0.23221655],
[ 0.05509637, 0.3996864 , -0.19793113],
[ 0.07731982, -0.40226308, 0.10213734]], dtype=float32), array([ 0.01846951, -0.01881272, -0.01897269], dtype=float32)]
1/1 [==============================] - 0s 7ms/step - loss: 0.2337 - accuracy: 0.0000e+00
Epoch 3/3
1/1 [==============================] - ETA: 0s - loss: 0.2322 - accuracy: 0.0000e+00
[array([[ 0.16706704, -0.448164 , -0.07889817],
[-0.19894598, 0.0381193 , -0.63007975],
[ 0.5558067 , -0.27921563, -0.22997847],
[ 0.05663134, 0.3981127 , -0.19951159],
[ 0.07536169, -0.400249 , 0.10415838]], dtype=float32), array([ 0.01846951, -0.01881272, -0.01897269], dtype=float32)]
1/1 [==============================] - 0s 2ms/step - loss: 0.2322 - accuracy: 0.0000e+00