在 Keras 中,如何修改每个批次的损失(使用训练期间应运行的额外代码)

2024-04-21

使用此自定义回调,我可以 1) 查看训练期间的损失 2) 访问正在训练的模型

class ChangeBatchLoss(tf.keras.callbacks.Callback):
    def on_train_batch_begin(self, batch, logs=None):
        if 't_loss' in logs:
            print(logs, file=sys.stderr)
            print(self.model, file=sys.stderr)

我的问题是:是否可以在训练期间修改损失本身? (我想执行一些额外的计算并添加/减去损失(在我的代码中,“损失”对应于日志[“t_loss”]显示的值。

任何想法?

Thanks


  1. 为相关模型创建自定义模型
  2. 覆盖函数 _make_train_function 并修改 self.metrics_tensors 或 self.total_loss
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在 Keras 中,如何修改每个批次的损失(使用训练期间应运行的额外代码) 的相关文章

随机推荐