使用三元组损失连体神经网络模型进行评估(model.evaluate)-tensorflow

2024-01-01

我训练了一个使用三重态损失的连体神经网络。这很痛苦,但我想我做到了。然而,我很难理解如何用这个模型进行评估。

The SNN:

def triplet_loss(y_true, y_pred):
    margin = K.constant(1)
    return K.mean(K.maximum(K.constant(0), K.square(y_pred[:,0]) - 0.5*(K.square(y_pred[:,1])+K.square(y_pred[:,2])) + margin))

def euclidean_distance(vects):
    x, y = vects
    return K.sqrt(K.maximum(K.sum(K.square(x - y), axis=1, keepdims=True), K.epsilon()))
anchor_input = Input((max_len, ), name='anchor_input')
positive_input = Input((max_len, ), name='positive_input')
negative_input = Input((max_len, ), name='negative_input')

Shared_DNN = create_base_network(embedding_dim = EMBEDDING_DIM, max_len=MAX_LEN, embed_matrix=embed_matrix)

encoded_anchor = Shared_DNN(anchor_input)
encoded_positive = Shared_DNN(positive_input)
encoded_negative = Shared_DNN(negative_input)

positive_dist = Lambda(euclidean_distance, name='pos_dist')([encoded_anchor, encoded_positive])
negative_dist = Lambda(euclidean_distance, name='neg_dist')([encoded_anchor, encoded_negative])
tertiary_dist = Lambda(euclidean_distance, name='ter_dist')([encoded_positive, encoded_negative])

stacked_dists = Lambda(lambda vects: K.stack(vects, axis=1), name='stacked_dists')([positive_dist, negative_dist, tertiary_dist])

model = Model([anchor_input, positive_input, negative_input], stacked_dists, name='triple_siamese')

model.compile(loss=triplet_loss, optimizer=adam_optim, metrics=[accuracy])
history = model.fit([Anchor,Positive,Negative],y=Y_dummy,validation_data=([Anchor_test,Positive_test,Negative_test],Y_dummy2), batch_size=128, epochs=25)

我知道,一旦使用三元组训练模型,评估实际上不应该要求使用三元组。然而,我该如何进行这种重塑呢?

因为这是一个 SNN,所以我想将两个输入输入model.evaluate,以及表示两个输入是否相似的分类变量(1 = similar, 0 = not similar).

所以基本上,我想要model.evaluate(input1, input2, y_label)。但我不确定如何用我训练的模型得到这个。如上所示,我使用三个输入进行训练:model.fit([Anchor,Positive,Negative],y=Y_dummy ... ) .

我知道我应该保存训练模型的权重,但我只是不知道将权重加载到哪个模型上。

非常感谢您的帮助!

EDIT: 我知道以下预测方法,但我不是在寻找预测,我希望使用model.evaluate因为我想获得模型损失/准确性的一些最终衡量标准。此外,这种方法仅将锚点输入到模型中(而我对文本相似性感兴趣,因此想要输入 2 个输入)

eval_model = Model(inputs=anchor_input, outputs=encoded_anchor)
eval_model.load_weights('weights.hdf5')


考虑到eval_model被训练来生成嵌入,我认为应该很好地使用以下方法来评估两个嵌入之间的相似性余弦相似度 https://www.tensorflow.org/api_docs/python/tf/keras/losses/cosine_similarity.

根据TF文档,余弦相似度是-1到1之间的数字。当它是接近-1的负数时,表示相似度更大。当它是接近1的正数时,表明差异较大。

我们可以简单地计算所有可用样本的正输入和负输入之间的余弦相似度。当余弦相似度 (1 = similar, 0 = not similar)。最后,可以计算二进制精度作为最终指标。

我们可以使用 TF 进行所有计算,而无需使用model.evaluate.

eval_model = Model(inputs=anchor_input, outputs=encoded_anchor)
eval_model.load_weights('weights.hdf5')

cos_sim = tf.keras.losses.cosine_similarity(
    eval_model(X1), eval_model(X2)
).numpy().reshape(-1,1)

accuracy = tf.reduce_mean(tf.keras.metrics.binary_accuracy(Y, -cos_sim, threshold=0))

另一种方法 https://keras.io/examples/vision/siamese_network/包括计算锚点和正图像之间的余弦相似度,并将其与锚点和负图像之间的相似度进行比较。

eval_model = Model(inputs=anchor_input, outputs=encoded_anchor)
eval_model.load_weights('weights.hdf5')

positive_similarity = tf.keras.losses.cosine_similarity(
    eval_model(X_anchor), eval_model(X_positive)
).numpy().mean()

negative_similarity = tf.keras.losses.cosine_similarity(
    eval_model(X_anchor), eval_model(X_negative)
).numpy().mean()

我们应该期望锚点和正图像之间的相似度大于锚点和负图像之间的相似度。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用三元组损失连体神经网络模型进行评估(model.evaluate)-tensorflow 的相关文章

随机推荐

  • 从指针到成员的映射

    Note in case this feels like an X Y problem scroll below the separator for how I arrived at this question 我正在寻找一种方法来存储指向
  • 如何在 QML 中创建矩形滚动条

    就像网页一样 当内容超出矩形时 就会出现滚动条 还有其他人可以帮助我吗 我尝试过使用列表视图 但无法在矩形中使用它 文档中有一个例子 如何使用ScrollBar https doc qt io qt 5 qml qtquick contro
  • 如何使用 Intellij 插件创建自定义实时模板

    我想创建一个可与我的插件一起使用的自定义实时模板 我知道如何使用 设置 对话框创建自定义实时模板 但我希望能够将实时模板作为我的插件的一部分分发 怎么样实时模板在插件中定义 在应用程序中注册它的入口点在哪里 Thanks 使用12 1 5
  • SASS:获取现有背景字符串的值并添加到其中?

    我想在 SASS Compass 中额外构建背景 而不考虑现有的背景字符串 我可以通过写入全局变量来完成 但看起来很草率 Pseudo mixin add icon add a background icon mixin add gradi
  • 引入先前证明的定理作为假设

    假设我已经在coq中证明了某个定理 稍后我想将其作为假设引入到另一个定理的证明中 有没有一种简洁的方法来做到这一点 当我想做一些诸如案例证明之类的事情时 我通常会出现这种需要 我发现做到这一点的一种方法是assert陈述定理 然后立即证明它
  • 如何使用 Zeromq 的 inproc 和 ipc 传输?

    我是 ZERMQ 的新手 ZeroMQ 具有 TCP INPROC 和 IPC 传输 我正在寻找在 Winx64 和 python 2 7 中使用 python 和 inproc 的示例 这些示例也可以用于 Linux 另外 我一直在寻找
  • 无法加载 Boost.Python 模块 - 未定义的符号

    我有一个用 C 编写的库 需要从 Python 访问 所以我使用 Boost Python 包装它 我可以毫无问题地将我的库编译成 Boost so 文件 但是当我尝试将其加载到 Python 中时 使用import tropmodboos
  • 改造 404 未找到 Web api

    我有一个网络 API 和一个应用程序 所以我想要一个注册应用程序 但我有一个问题 我用的是天蓝色的 有我的registerapi 界面 FormUrlEncoded POST application json public void ins
  • 如何在 Yocto 构建中将第三方库添加为包

    我有一个不知名的库 并且没有适用于该库的包https github com dailab libsml https github com dailab libsml通常我通过以下方式在我的设备上安装这个库make install如何将此库作
  • 获取数学函数作为用户的输入

    我需要知道如何将字符串输入传输到可执行函数 例如 用户编写字符串 x Sin x 2 然后程序将其作为函数 可以计算给定 x 的值 可以绘制该函数的推导图等 我读到有一个名为scitools stringfunction 但据我所知该模块在
  • 返回时如何跳过浏览器历史记录中的页面?

    我有一个带有路由器的 Angular 2 应用程序 假设用户位于应用程序中的页面 A 然后导航到页面 B 然后导航到页面 C 此时 当他单击浏览器上的 后退 按钮时 我希望他返回到页面 A 跳过 B 我怎样才能实现它 当从 B 导航到 C
  • 如何将环境变量传递给使用自定义容器创建的 gcloud beta ai 自定义作业 (Vertex AI)

    我正在谷歌的 Vertex AI 中运行自定义训练作业 一个简单的gcloud执行自定义作业的命令将使用类似以下语法的内容 可以查看该命令的完整文档here https cloud google com sdk gcloud referen
  • 停止 VS 2010 在 else 关键字后自动创建大括号

    我正在使用 VS 2010 当我输入 else 然后它自动返回行并添加大括号时 我总是感到恼火 就像是 else 我无法想象我是唯一一个经常在其他内容后面加上俏皮话并且不喜欢大括号的人 我该如何阻止这种情况发生 我也发现这种行为非常烦人 我
  • URL编码iOS NSURL错误

    在桌面上的 Firefox Chrome 浏览器中打开的 URL 在 iPhone 上的 WebView 中无法打开 该 URL 据称正在访问 GET 请求 创建 NSURL 时不进行百分号转义 则不会生成 url 使用percentesc
  • php:将整个 $_POST 变量保存在会话中

    这是否有效 SESSION pictures rateAlbum POST POST 我想一次性保存会话中的所有 POST 数据 编辑 哦 反过来呢 POST SESSION pictures rateAlbum POST 是的你可以 如果
  • 使用 chrome 暂停下载无法按预期工作

    我试图暂停下载 但不起作用 文件已下载 这是我的代码 在我的后台脚本中 chrome downloads onCreated addListener function e chrome downloads pause e id Here c
  • 使用 Response.End(false) 与 ApplicationInstance.CompleteRequest() 之间/用例之间有什么区别

    我遇到了一个讨论使用的问题ApplicationInstance CompleteRequest 以避免ThreadAbortException被抛出时Response End 叫做 过去 为了避免我上面提到的异常错误 我使用了这个重载 R
  • 使用 .next() 或 .nextLine() 的 java 字符串变量

    以下是我的源代码 package functiontest import java io BufferedWriter import java io File import java io FileWriter import java io
  • 使用自定义名称创建工作簿而不将其保存到磁盘

    是否可以创建具有自定义名称的工作簿而不将其保存到磁盘 我想避免使用默认的 工作簿 x 名称 但我不想要求用户保存工作簿 如果我自动将其保存在某个临时文件中 则用户单击 保存 时将不会出现 另存为 对话框 这可能会令人困惑 只需创建工作簿而不
  • 使用三元组损失连体神经网络模型进行评估(model.evaluate)-tensorflow

    我训练了一个使用三重态损失的连体神经网络 这很痛苦 但我想我做到了 然而 我很难理解如何用这个模型进行评估 The SNN def triplet loss y true y pred margin K constant 1 return