如何访问生成器提供的 Keras 自定义损失函数中的样本权重?

2024-02-27

我有一个生成器函数,可以无限循环某些图像目录并输出 3 元组的批次形式

[img1, img2], label, weight

where img1 and img2 are batch_size x M x N x 3张量,以及label and weight是每个batch_sizex 1 张量。

我将这个生成器提供给fit_generator使用 Keras 训练模型时的函数。

对于这个模型,我有一个自定义的余弦对比损失函数,

def cosine_constrastive_loss(y_true, y_pred):
    cosine_distance = 1 - y_pred
    margin = 0.9
    cdist = y_true * y_pred + (1 - y_true) * keras.backend.maximum(margin - y_pred, 0.0)
    return keras.backend.mean(cdist)

从结构上讲,我的模型一切正常。没有错误,并且它正在按预期消耗来自生成器的输入和标签。

但现在我正在寻求直接使用每个批次的权重参数并在其中执行一些自定义逻辑cosine_contrastive_loss基于样品特定的重量。

如何在执行损失函数时从一批样本的结构中访问此参数?

请注意,由于它是一个无限循环的生成器,因此不可能预先计算权重或动态计算它们以将权重咖喱到损失函数中或生成它们。

它们必须与正在生成的样本一致生成,实际上我的数据生成器中有自定义逻辑,可以根据以下属性动态确定权重img1, img2 and label目前它们是为一批生成的。


手动训练循环替代方案

我唯一能想到的是手动训练循环,您可以自己获取权重。

有一个权重张量和一个不可变的批量大小:

weights = K.variable(np.zeros((batch_size,)))

在您的自定义损失中使用它们:

def custom_loss(true, pred):
    return someCalculation(true, pred, weights)

对于“生成器”:

for e in range(epochs):
    for s in range(steps_per_epoch):
        x, y, w = next(generator) #or generator.next(), not sure
        K.set_value(weights, w)

        model.train_on_batch(x, y)

For a keras.utils.Sequence:

for e in range(epochs):
    for s in range(len(generator)):
        x,y,w = generator[s]

        K.set_value(weights, w)
        model.train_on_batch(x,y)

我知道这个答案不是最佳的,因为它不会并行化从生成器获取数据,因为它发生在fit_generator。但这是我能想到的最好的简单解决方案。 Keras 没有公开权重,它们会自动应用在一些隐藏的源代码中。


让模型计算权重替代方案

如果可以计算权重x and y,您可以将此任务委托给损失函数本身。

这有点hacky,但可能有效:

input1 = Input(shape1)
input2 = Input(shape2)

# .... model creation .... #

model = Model([input1, input2], outputs)

让损失者能够获得input1 and input2:

def custom_loss(y_true, y_pred):
    w = calculate_weights(input1, input2, y_pred)
    # .... rest of the loss .... #

这里的问题是您是否可以根据输入将权重计算为张量。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何访问生成器提供的 Keras 自定义损失函数中的样本权重? 的相关文章

  • 如何使用一个模型中间层的输出作为另一个模型的输入?

    我训练一个模型A并尝试使用中间层的输出name layer x 作为模型的附加输入B 我尝试像 Keras 文档一样使用中间层的输出https keras io getting started faq how can i obtain th
  • Flask 中“缺少 CSRF 令牌”,但它在模板中呈现

    问题 当我尝试登录 使用 Flask login 时 我得到Bad Request The CSRF session token is missing但令牌正在呈现 在模板中 secret key 已设置 并且我在本地运行localhost
  • Keras model.predict 函数给出输入形状错误

    我已经在 Tensorflow 中实现了通用句子编码器 现在我正在尝试预测句子的类概率 我也将字符串转换为数组 Code if model model type universal classifier basic class probs
  • 为什么在连接两个字符串时 Python 比 C 更快?

    目前我想比较 Python 和 C 用来处理字符串的速度 我认为 C 应该比 Python 提供更好的性能 然而 我得到了完全相反的结果 这是 C 程序 include
  • App Engine NDB:如何访问属性的 verbose_name

    假设我有这个代码 class A ndb Model prop ndb StringProperty verbose name Something m A m prop a string value 当然 现在如果我打印 m prop 它会
  • 使用pathlib获取主目录

    翻看新的pathlib在 Python 3 4 中 我注意到没有任何简单的方法来获取用户的主目录 我能想到的获取用户主目录的唯一方法是使用旧的os path像这样的库 import pathlib from os import path p
  • 在linux上安装python ssl模块,无需重新编译

    是否可以在已经安装了 OpenSSL 的 Linux 机器上安装 python 的 SSL 模块 而无需重新编译 python 我希望它就像复制几个文件并将它们包含在库路径中一样简单 Python版本是2 4 3 谢谢 是否可以在已经安装了
  • ValueError:数据必须为正(boxcox scipy)

    我正在尝试将我的数据集转换为正态分布 0 8 298511e 03 1 3 055319e 01 2 6 938647e 02 3 2 904091e 02 4 7 422441e 02 5 6 074046e 02 6 9 265747e
  • 使用 subprocess.Popen() 或 subprocess.check_call() 时程序卡住

    我想从 python 运行一个程序并找到它的内存使用情况 为此 我正在使用 l a out lt in txt gt out txt p subprocess Popen l shell False stdout subprocess PI
  • 在请求中设置端口

    我正在尝试利用cgminer使用 Python 的 API 我对利用requests图书馆 我了解如何做基本的事情requests but cgminer想要更具体一点 我想缩小 import socket import json sock
  • 使用 if 语句的网格网格和用户定义函数的真值不明确

    假设我有一个函数f x y 足够光滑 然而 有些值仅在有限的意义上存在 以sin x x的价值x 0只存在于极限 x gt 0 中 在一般情况下 我用一个来处理这个问题if陈述 如果我在情节中使用它meshgrid我收到一条错误消息 Val
  • Python 属性和 Swig

    我正在尝试使用 swig 为一些 C 代码创建 python 绑定 我似乎遇到了一个问题 试图从我拥有的一些访问器函数创建 python 属性 方法如下 class Player public void entity Entity enti
  • Python:在字典中查找具有唯一值的键?

    我收到一个字典作为输入 并且想要返回一个键列表 其中字典值在该字典的范围内是唯一的 我将用一个例子来澄清 假设我的输入是字典 a 构造如下 a dict a cat 1 a fish 1 a dog 2 lt unique a bat 3
  • python Recipe:列出最接近等于值的项[关闭]

    Closed 这个问题需要多问focused help closed questions 目前不接受答案 考虑像这样的列表 0 3 7 10 12 15 19 21 我想获得最接近任何值的最近的最小数字 所以如果我通过4 我会得到3 如果我
  • 为什么 Collections.counter 这么慢?

    我正在尝试解决罗莎琳德的基本问题 即计算给定序列中的核苷酸 并在列表中返回结果 对于那些不熟悉生物信息学的人来说 它只是计算字符串中 4 个不同字符 A C G T 出现的次数 我期望collections Counter是最快的方法 首先
  • 从 wxPython 事件处理程序中调用函数

    我正在努力寻找一种在 wxPython 事件处理函数中使用函数的方法 假设我有一个按钮 单击该按钮时 它会使用事件处理程序运行一个名为 OnRun 的函数 但是 用户忘记单击 OnRun 按钮之前的 RadionButton 我想弹出一个
  • Python 读取未格式化的直接访问 Fortran 90 给出不正确的输出

    这是数据的写入方式 它是一个二维浮点矩阵 我不确定大小 open unit 51 file rmsd nn output form unformatted access direct status replace recl Npoints
  • 两种 ODE 求解器之间的差异

    我想知道 两者之间有什么区别ODEINT and solve ivp用于求解微分方程 它们之间有什么优点和缺点 f1 solve ivp f 0 1 y0 y0 is the initial point f2 odeint f y0 0 1
  • 如何同时接受int和float类型的输入?

    我正在制作一个货币转换器 如何让 python 同时接受整数和浮点数 我就是这样做的 def aud brl amount From to ER 0 42108 if amount int if From strip aud and to
  • 基于值的 matplotlib 条形图颜色

    有没有一种方法可以根据条形图的值对条形图的条形进行着色 例如 values below 0 5 red values between 0 5 to 0 green values between 0 to 08 blue etc 我找到了一些

随机推荐