如何防止 Keras 在训练期间计算指标

2024-05-08

我正在使用 Tensorflow/Keras 2.4.1,并且有一个(无监督的)自定义指标,它将我的几个模型输入作为参数,例如:

model = build_model() # returns a tf.keras.Model object
my_metric = custom_metric(model.output, model.input[0], model.input[1])
model.add_metric(my_metric)
[...]
model.fit([...]) # training with fit

然而,碰巧的是custom_metric非常昂贵,所以我希望仅在验证期间计算它。我找到了这个answer https://stackoverflow.com/a/60829012/6315123但我几乎不明白如何使解决方案适应我的指标,该指标使用多个模型输入作为参数,因为update_state方法好像不太灵活。

在我的上下文中,除了编写我自己的训练循环之外,是否有办法避免在训练期间计算我的指标? 另外,我很惊讶我们无法本机指定 Tensorflow 某些指标只能在验证时计算,这有什么原因吗?

此外,由于模型经过训练来优化损失,并且训练数据集不应用于评估模型,我什至不明白为什么默认情况下 Tensorflow 在训练期间计算指标。


我认为仅在验证时计算指标的最简单解决方案是使用自定义回调。

在这里我们定义我们的虚拟回调:

class MyCustomMetricCallback(tf.keras.callbacks.Callback):

    def __init__(self, train=None, validation=None):
        super(MyCustomMetricCallback, self).__init__()
        self.train = train
        self.validation = validation

    def on_epoch_end(self, epoch, logs={}):

        mse = tf.keras.losses.mean_squared_error

        if self.train:
            logs['my_metric_train'] = float('inf')
            X_train, y_train = self.train[0], self.train[1]
            y_pred = self.model.predict(X_train)
            score = mse(y_train, y_pred)
            logs['my_metric_train'] = np.round(score, 5)

        if self.validation:
            logs['my_metric_val'] = float('inf')
            X_valid, y_valid = self.validation[0], self.validation[1]
            y_pred = self.model.predict(X_valid)
            val_score = mse(y_pred, y_valid)
            logs['my_metric_val'] = np.round(val_score, 5)

给定这个虚拟模型:

def build_model():

  inp1 = Input((5,))
  inp2 = Input((5,))
  out = Concatenate()([inp1, inp2])
  out = Dense(1)(out)

  model = Model([inp1, inp2], out)
  model.compile(loss='mse', optimizer='adam')

  return model

和这个数据:

X_train1 = np.random.uniform(0,1, (100,5))
X_train2 = np.random.uniform(0,1, (100,5))
y_train = np.random.uniform(0,1, (100,1))

X_val1 = np.random.uniform(0,1, (100,5))
X_val2 = np.random.uniform(0,1, (100,5))
y_val = np.random.uniform(0,1, (100,1))

您可以使用自定义回调来计算训练和验证上的指标:

model = build_model()

model.fit([X_train1, X_train2], y_train, epochs=10, 
          callbacks=[MyCustomMetricCallback(train=([X_train1, X_train2],y_train), validation=([X_val1, X_val2],y_val))])

仅在验证时:

model = build_model()

model.fit([X_train1, X_train2], y_train, epochs=10, 
          callbacks=[MyCustomMetricCallback(validation=([X_val1, X_val2],y_val))])

仅限火车上:

model = build_model()

model.fit([X_train1, X_train2], y_train, epochs=10, 
          callbacks=[MyCustomMetricCallback(train=([X_train1, X_train2],y_train))])

只记得这一点回调一次性评估指标在数据上,就像 keras 默认计算的任何指标/损失一样validation_data.

here https://colab.research.google.com/drive/1ZT9jDHxTQLWzc8rMJ5dfgUT77Rr3j-GI?usp=sharing是正在运行的代码。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何防止 Keras 在训练期间计算指标 的相关文章

随机推荐

  • IntelliJ IDEA 中查找方法/函数的快捷方式是什么?

    I know that Ctrl N is to find classes and it is very useful But what about methods ctrl F12 cmd F12 on macOS will show a
  • 可滚动Div,哪些元素可以看到

    我们有一个带有 CSS 的可滚动 divhieght 40px 里面有多个LIheight 20px div li title I1 item1 li li title I2 item2 li li title I3 item3 li li
  • 轮询时承诺异步等待

    我正在尝试将使用承诺 和轮询 的函数转换为异步函数 但我不太确定它是如何工作的 我有这个 function myFunction return new Promise resolve gt stuff here var poll setIn
  • 如何在 Bixby 输入视图中使用语音输入选择?

    目前 我设置了一个提示来收集用户的电子邮件 它在选择输入视图中提供存储在用户配置文件中的选项 但是 如果用户通过语音与 Bixby 交互 例如 可能他们的手很忙 是否有办法为这些选项提供别名 以便用户只需说 家庭 即可获取家庭电子邮件或 工
  • 无锁算法中的 ABA

    我明白了ABA http en wikipedia org wiki ABA problem问题 但我无法理解的是 他们说在语言中自动垃圾收集它可能不会展示 所以我的问题是 自动垃圾收集如何防止ABA问题的发生 在java中是否可能 如果可
  • 连接错误 - SQLSTATE[HY000] [2002] 操作超时

    我在从 Windows 2008 R2 应用程序服务器连接到也在 Windows 2008 R2 上运行的 MySQL 服务器时遇到问题 Laravel 应用程序报告错误 exception PDOException with messag
  • gnuplot 中的 output.png 不如提示 shell 中的图好

    我经常绘制图表gnuplot提示 shell 如下所示 gunuplot gt plot sin x with linespoints pointtype 3 出现的数字很棒 今天 我将图表保存在 png文件 像这样 gnuplot gt
  • 我如何解决语义错误:“类没有名为..”的关联

    我正在关注 symblog symfony2 教程的第 5 部分 http tutorial symblog co uk docs customising the view more with twig html 标题下 主页 博客和评论
  • 是否可以在 gnuplot 中设置标签相对于键的位置?

    我的情节的本质是这样的 绝对的标签并不能真正发挥作用 我无法限制 y 中的范围 所以想知道是否有办法将我的标签文本包含在键内或将其相对于键放置 即下面 set term png enhanced size 1024 768 set titl
  • 运行时错误:范围自动筛选上的“1004”

    我想用 VBA 做什么 使用数组过滤表并删除行 我的数组有 4 个元素 在循环中更改为有 5 个不同的集合 正在过滤的列有 5 个元素 我只想得到 1 这是一个循环 它将创建 5 个报告 每个报告根据第 29 列过滤不同的元素 如果在调试模
  • 如何在调用处替换内联函数代码?

    我想知道内联函数调用是如何被内联代码替换的 我在一些书中读到编译器可能会将内联函数视为普通函数 任何人都可以解释内联函数是如何工作的 来自 C 常见问题解答 http www parashift com c faq lite inline
  • 属性编辑器未向 PropertyEditorManager 注册:自定义标记调用时出错

    调用我的时出现以下错误testtag jsp org apache jasper JasperException 无法将属性 att1 的字符串 转换为类 javax servlet jsp tagext JspFragment 属性编辑器
  • 缓动不适用于toggleClass() 或addClass()

    我有一个在页面上显示和隐藏实用工具栏的功能 我想将其动画化 这不是动画 类 标志 是空的 min 类只是更改背景图像以及实用工具栏的高度和绝对位置 我究竟做错了什么 document ready function var ubar ccUt
  • 为什么 Java BufferedReader() 不能正确读取阿拉伯文和中文字符?

    我正在尝试读取一个每行包含英文和阿拉伯字符的文件以及另一个每行包含英文和中文字符的文件 然而 阿拉伯文和中文的字符无法正确显示 它们只是显示为问号 知道我该如何解决这个问题吗 这是我用于阅读的代码 try String sCurrentLi
  • 比较 C# 中 DateTime 的二进制表示形式

    我有一个DateTime表示为长 8 个字节 来自DateTime ToBinary 我们称之为dateTimeBin 是否有一种最佳方法可以删除时间信息 我只关心日期 以便我可以将其与一天的开始进行比较 假设我们将此样本值作为一天的开始
  • 为什么 Rust 不允许在一种类型上复制和删除特征?

    From the book https doc rust lang org book 2018 edition ch04 01 what is ownership html stack only data copy Rust 不允许我们用C
  • 如何在 Mac OS X 10.9.5 上以编程方式读取低功耗蓝牙传输的数据?

    我正在尝试阅读蓝牙低功耗 http www bluetooth com Pages low energy tech info aspx使用 Ruby 以编程方式传输数据 低功耗蓝牙技术不支持标准规范 v4 0 中的串行端口配置文件 SPP
  • 如何使用 Typescript 从 mui 扩展调色板

    我正在尝试扩展 mui 提供的调色板 覆盖主色 次要颜色等效果很好 但如果我想在之后创建一组自定义颜色 我不知道如何使其工作 有很多没有打字稿的例子 但是当这个人进入游戏时 事情就变得更加棘手 假设我有这个 主题 tsx palette p
  • CSS 轮廓宽度不起作用

    我正在尝试将输入元素的轮廓宽度设置为焦点 无论我的设置如何 轮廓宽度都保持不变 就像它是无法更改的默认设置一样 这是来自 codepen 的示例 http codepen io FrenkyB pen mEaEyL editors 1100
  • 如何防止 Keras 在训练期间计算指标

    我正在使用 Tensorflow Keras 2 4 1 并且有一个 无监督的 自定义指标 它将我的几个模型输入作为参数 例如 model build model returns a tf keras Model object my met