我训练了一个使用三重态损失的连体神经网络。这很痛苦,但我想我做到了。然而,我很难理解如何用这个模型进行评估。
The SNN:
def triplet_loss(y_true, y_pred):
margin = K.constant(1)
return K.mean(K.maximum(K.constant(0), K.square(y_pred[:,0]) - 0.5*(K.square(y_pred[:,1])+K.square(y_pred[:,2])) + margin))
def euclidean_distance(vects):
x, y = vects
return K.sqrt(K.maximum(K.sum(K.square(x - y), axis=1, keepdims=True), K.epsilon()))
anchor_input = Input((max_len, ), name='anchor_input')
positive_input = Input((max_len, ), name='positive_input')
negative_input = Input((max_len, ), name='negative_input')
Shared_DNN = create_base_network(embedding_dim = EMBEDDING_DIM, max_len=MAX_LEN, embed_matrix=embed_matrix)
encoded_anchor = Shared_DNN(anchor_input)
encoded_positive = Shared_DNN(positive_input)
encoded_negative = Shared_DNN(negative_input)
positive_dist = Lambda(euclidean_distance, name='pos_dist')([encoded_anchor, encoded_positive])
negative_dist = Lambda(euclidean_distance, name='neg_dist')([encoded_anchor, encoded_negative])
tertiary_dist = Lambda(euclidean_distance, name='ter_dist')([encoded_positive, encoded_negative])
stacked_dists = Lambda(lambda vects: K.stack(vects, axis=1), name='stacked_dists')([positive_dist, negative_dist, tertiary_dist])
model = Model([anchor_input, positive_input, negative_input], stacked_dists, name='triple_siamese')
model.compile(loss=triplet_loss, optimizer=adam_optim, metrics=[accuracy])
history = model.fit([Anchor,Positive,Negative],y=Y_dummy,validation_data=([Anchor_test,Positive_test,Negative_test],Y_dummy2), batch_size=128, epochs=25)
我知道,一旦使用三元组训练模型,评估实际上不应该要求使用三元组。然而,我该如何进行这种重塑呢?
因为这是一个 SNN,所以我想将两个输入输入model.evaluate
,以及表示两个输入是否相似的分类变量(1 = similar, 0 = not similar)
.
所以基本上,我想要model.evaluate(input1, input2, y_label)
。但我不确定如何用我训练的模型得到这个。如上所示,我使用三个输入进行训练:model.fit([Anchor,Positive,Negative],y=Y_dummy ... )
.
我知道我应该保存训练模型的权重,但我只是不知道将权重加载到哪个模型上。
非常感谢您的帮助!
EDIT:
我知道以下预测方法,但我不是在寻找预测,我希望使用model.evaluate
因为我想获得模型损失/准确性的一些最终衡量标准。此外,这种方法仅将锚点输入到模型中(而我对文本相似性感兴趣,因此想要输入 2 个输入)
eval_model = Model(inputs=anchor_input, outputs=encoded_anchor)
eval_model.load_weights('weights.hdf5')