你可以试试后台功能K.in_train_phase()
,这是由Dropout
and BatchNormalization
层在训练和验证中实现不同的行为。
def custom_loss(y_true, y_pred):
weighted_loss = ... # your implementation of weighted crossentropy loss
unweighted_loss = K.sparse_categorical_crossentropy(y_true, y_pred)
return K.in_train_phase(weighted_loss, unweighted_loss)
第一个参数K.in_train_phase()
是训练阶段使用的张量,第二个是测试阶段使用的张量。
例如,如果我们设置weighted_loss
为0(只是为了验证效果K.in_train_phase()
功能):
def custom_loss(y_true, y_pred):
weighted_loss = 0 * K.sparse_categorical_crossentropy(y_true, y_pred)
unweighted_loss = K.sparse_categorical_crossentropy(y_true, y_pred)
return K.in_train_phase(weighted_loss, unweighted_loss)
model = Sequential([Dense(100, activation='relu', input_shape=(100,)), Dense(1000, activation='softmax')])
model.compile(optimizer='adam', loss=custom_loss)
model.outputs[0]._uses_learning_phase = True # required if no dropout or batch norm in the model
X = np.random.rand(1000, 100)
y = np.random.randint(1000, size=1000)
model.fit(X, y, validation_split=0.1)
Epoch 1/10
900/900 [==============================] - 1s 868us/step - loss: 0.0000e+00 - val_loss: 6.9438
可以看到,训练阶段的loss确实是1乘以0。
请注意,如果您的模型中没有 dropout 或批量归一化,您需要手动“打开”_uses_learning_phase
布尔开关,否则K.in_train_phase()
默认情况下不会有任何影响。