在 Keras/Tensorflow 自定义损失函数中使用额外的“可训练”变量

2024-03-01

我知道如何在 Keras 中使用附加输入(而不是标准输入)编写自定义损失函数y_true, y_pred配对,见下文。我的问题是输入损失函数可训练的变量(其中一些)是损失梯度的一部分,因此应该更新。

我的解决方法是:

  • 输入网络的虚拟输入NXV大小在哪里N是观测值的数量,V附加变量的数量
  • Add a Dense() layer dummy_output这样 Keras 就会跟踪我的V“权重”
  • 使用该层的V我的真实输出层的自定义损失函数中的权重
  • 为此使用虚拟损失函数(仅返回 0.0 和/或权重 0.0)dummy_output层所以我的V“权重”仅通过我的自定义损失函数更新

我的问题是:有没有更自然的类似 Keras/TF 的方法来做到这一点?因为它感觉很做作,更不用说容易出现错误。

我的解决方法示例:

(是的,我知道这是一个非常愚蠢的自定义损失函数,实际上事情要复杂得多)

import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import EarlyStopping
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Input
from tensorflow.keras import Model

n_col = 10
n_row = 1000
X = np.random.normal(size=(n_row, n_col))
beta = np.arange(10)
y = X @ beta

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# my custom loss function accepting my dummy layer with 2 variables
def custom_loss_builder(dummy_layer):
    def custom_loss(y_true, y_pred):
        var1 = dummy_layer.trainable_weights[0][0]
        var2 = dummy_layer.trainable_weights[0][1]
        return var1 * K.mean(K.square(y_true-y_pred)) + var2 ** 2 # so var2 should get to zero, var1 should get to minus infinity?
    return custom_loss

# my dummy loss function
def dummy_loss(y_true, y_pred):
    return 0.0

# my dummy input, N X V, where V is 2 for 2 vars
dummy_x_train = np.random.normal(size=(X_train.shape[0], 2)) 

# model
inputs = Input(shape=(X_train.shape[1],))
dummy_input = Input(shape=(dummy_x_train.shape[1],))
hidden1 = Dense(10)(inputs) # here only 1 hidden layer in the "real" network, assume whatever network is built here
output = Dense(1)(hidden1)
dummy_output = Dense(1, use_bias=False)(dummy_input)
model = Model(inputs=[inputs, dummy_input], outputs=[output, dummy_output])

# compilation, notice zero loss for the dummy_output layer
model.compile(
  loss=[custom_loss_builder(model.layers[-1]), dummy_loss],
  loss_weights=[1.0, 0.0], optimizer= 'adam')

# run, notice y_train repeating for dummy_output layer, it will not be used, could have created dummy_y_train as well
history = model.fit([X_train, dummy_x_train], [y_train, y_train],
                    batch_size=32, epochs=100, validation_split=0.1, verbose=0,
                   callbacks=[EarlyStopping(monitor='val_loss', patience=5)])

似乎确实可以正常工作,无论起始值如何var1 and var2(初始化的dummy_output层)他们渴望负inf and 0分别:

(该图来自迭代运行模型并保存这两个权重,如下所示)

var1_list = []
var2_list = []
for i in range(100):
    if i % 10 == 0:
        print('step %d' % i)
    model.fit([X_train, dummy_x_train], [y_train, y_train],
              batch_size=32, epochs=1, validation_split=0.1, verbose=0)
    var1, var2 = model.layers[-1].get_weights()[0]
    var1_list.append(var1.item())
    var2_list.append(var2.item())

plt.plot(var1_list, label='var1')
plt.plot(var2_list, 'r', label='var2')
plt.legend()
plt.show()

在这里回答我自己的问题,经过几天的努力,我让它在没有虚拟输入的情况下工作,我认为这要好得多,并且应该是“规范”方式,直到 Keras/TF 简化过程。 Keras/TF 文档就是这样做的here https://tensorflow.google.cn/guide/keras/train_and_evaluate#handling_losses_and_metrics_that_dont_fit_the_standard_signature.

使用外部损失函数的关键可训练的变量是通过使用自定义损失/输出来实现的Layer其中有self.add_loss(...) in its call()实施,像这样:

class MyLoss(Layer):
    def __init__(self, var1, var2):
        super(MyLoss, self).__init__()
        self.var1 = K.variable(var1) # or tf.Variable(var1) etc.
        self.var2 = K.variable(var2)
    
    def get_vars(self):
        return self.var1, self.var2
    
    def custom_loss(self, y_true, y_pred):
        return self.var1 * K.mean(K.square(y_true-y_pred)) + self.var2 ** 2
    
    def call(self, y_true, y_pred):
        self.add_loss(self.custom_loss(y_true, y_pred))
        return y_pred

现在请注意MyLoss层需求two输入,实际y_true和预测的y直到那时:

inputs = Input(shape=(X_train.shape[1],))
y_input = Input(shape=(1,))
hidden1 = Dense(10)(inputs)
output = Dense(1)(hidden1)
my_loss = MyLoss(0.5, 0.5)(y_input, output) # here can also initialize those var1, var2
model = Model(inputs=[inputs, y_input], outputs=my_loss)

model.compile(optimizer= 'adam')

最后,正如 TF 文档提到的,在这种情况下,您不必指定loss or y in the fit()功能:

history = model.fit([X_train, y_train], None,
                    batch_size=32, epochs=100, validation_split=0.1, verbose=0,
                    callbacks=[EarlyStopping(monitor='val_loss', patience=5)])

再次请注意y_train进入fit()作为输入之一。

现在它可以工作了:

var1_list = []
var2_list = []
for i in range(100):
    if i % 10 == 0:
        print('step %d' % i)
    model.fit([X_train, y_train], None,
              batch_size=32, epochs=1, validation_split=0.1, verbose=0)
    var1, var2 = model.layers[-1].get_vars()
    var1_list.append(var1.numpy())
    var2_list.append(var2.numpy())

plt.plot(var1_list, label='var1')
plt.plot(var2_list, 'r', label='var2')
plt.legend()
plt.show()

(我还应该提到这个特定的模式var1, var2很大程度上取决于它们的初始值,如果var1的初始值大于 1,实际上不会减少,直到负数inf)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在 Keras/Tensorflow 自定义损失函数中使用额外的“可训练”变量 的相关文章

  • 静态文件配置不正确

    我已经在 Heroku 上部署了简单的博客应用程序 它运行在Django 1 8 4 我在静态文件方面遇到了一些问题 当打开我的应用程序时 我看到Application Error页面 所以我尝试调试它并发现当我提交到 Heroku 时它无
  • 使用信号时出现 django TransactionManagementError

    我有一个与 django 的用户和 UserInfo 一对一的字段 我想订阅用户模型上的 post save 回调函数 以便我也可以保存 UserInfo receiver post save sender User def saveUse
  • 创建一个打开文件并创建字典的函数

    我有一个正在处理的文件 我想创建一个读取文件并将内容放入字典中的函数 然后该字典需要通过 main 函数传递 这是主程序 它无法改变 我所做的一切都必须与主程序配合 def main sunspot dict file str raw in
  • 如何在Python中的BeautifulSoup4中使用.next_sibling时忽略空行

    由于我想删除 html 网站中重复的占位符 因此我使用 BeautifulSoup 的 next sibling 运算符 只要重复项位于同一行 就可以正常工作 参见数据 但有时它们之间有一个空行 所以我希望 next sibling 忽略它
  • 在 Python 3 中动态导入模块的问题

    我遇到的情况是 在我的 Python 3 项目中 在运行时必须包含某些模块 我在用着importlib import module为了这 第二次更新 我确实找到了一种方法来做一些接近我想要的事情 一些额外的代码可能会使我的一些链接稍微偏离一
  • 通过鼻子测试检查某个函数是否发出警告

    我正在使用编写单元测试nose http somethingaboutorange com mrl projects nose 0 11 2 我想检查函数是否引发警告 该函数使用warnings warn 这是很容易就能做到的事情吗 def
  • 使用字母而不是数字进行顺序计数[重复]

    这个问题在这里已经有答案了 我需要一种方法 将字符串 递增 到 z 然后将 aa 递增到 az 然后将 ba 递增到 bz 依此类推 就像 Excel 工作表中的列一样 我将向该方法提供前一个字符串 它应该增加到下一个字母 PSEUDO C
  • 使用 Python 的文本中的词频但忽略停用词

    这给了我文本中单词的频率 fullWords re findall r w allText d defaultdict int for word in fullWords d word 1 finalFreq sorted d iterit
  • Selenium:等到 WebElement 中的文本发生变化

    我在用着selenium使用Python 2 7 从网页上的搜索框检索内容 搜索框动态检索结果并在框本身中显示结果 from selenium import webdriver from selenium webdriver common
  • 如何使用python读取最后一行的特定位置

    我有一个太大的 txt 文件 并且有几行类似的行 如下所示 字1 字2 字3 字4 553 75 我对位置 4 值 感兴趣 即最后一行 553 75 我的文件文本 word1 word2 word3 word4 553 20 word1 w
  • 如何使用 msgpack 进行读写?

    如何序列化 反序列化字典data with msgpack http msgpack org The Python 文档 http msgpack python readthedocs io en latest badge latest似乎
  • 在Python中计算结构体的CRC

    我有以下结构 来自 C 中的 NRPE 守护程序代码 typedef struct packet struct int16 t packet version int16 t packet type uint32 t crc32 value
  • 如何在 Python 中执行相当于预处理器指令的操作?

    有没有办法在 Python 中执行以下预处理器指令 if DEBUG lt do some code gt else lt do some other code gt endif There s debug 这是编译器预处理的特殊值 if
  • 在tensorflow .ckpt文件中使用预训练模型

    我有一个 ckpt 文件 我只想得到 cnn 的权重 我已经从 ckpt 检查点文件中进行了训练 inception resnet v2 2016 08 30 import tensorflow as tf saver tf train S
  • 如何将 pytest 装置与 django TestCase 一起使用

    我如何在TestCase方法 类似问题的几个答案似乎暗示我的例子应该有效 import pytest from django test import TestCase from myapp models import Category py
  • 从 csv 中读取 pandas 数据帧,以非固定标头开始

    我有许多数据文件是由我的实验室中使用的一些相当黑客的脚本生成的 该脚本非常有趣 因为它在标头之前附加的行数因文件而异 尽管它们具有相同的格式并具有相同的标头 我正在编写一个批处理来将所有这些文件处理为数据帧 如果我不知道位置 如何让 pan
  • 从 Python 中编译的正则表达式中提取命名组正则表达式模式

    我有一个 Python 正则表达式 其中包含多个命名组 但是 如果先前的组已匹配 则可能会错过与一组匹配的模式 因为似乎不允许重叠 举个例子 import re myText sgasgAAAaoasgosaegnsBBBausgisego
  • Jupyter Notebook 中的多处理与线程

    我试图测试这个例子here https ipywidgets readthedocs io en stable examples Widget 20Asynchronous html将其从线程更改为多处理 在 jupyter Noteboo
  • 在读/写二进制数据结构时访问位域

    我正在为二进制格式编写一个解析器 这种二进制格式涉及不同的表 这些表同样采用二进制格式 通常包含不同的字段大小 其中 50 100 个之间 大多数这些结构都有位域 并且在 C 语言中表示时看起来像这样 struct myHeader uns
  • Shap - 颜色条不显示在摘要图中

    显示summary plot时 不显示颜色条 shap summary plot shap values X train 我尝试过改变plot size 当绘图较高时 会出现颜色条 但它非常小 看起来不应该 shap summary plo

随机推荐