Keras 中的 add_loss 函数的用途是什么?

2024-02-22

目前,我偶然发现了变分自动编码器,并尝试使用 keras 让它们在 MNIST 上工作。我找到了一个教程github https://github.com/keras-team/keras/blob/master/examples/variational_autoencoder.py.

我的问题涉及以下代码行:

# Build model
vae = Model(x, x_decoded_mean)

# Calculate custom loss
xent_loss = original_dim * metrics.binary_crossentropy(x, x_decoded_mean)
kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
vae_loss = K.mean(xent_loss + kl_loss)

# Compile
vae.add_loss(vae_loss)
vae.compile(optimizer='rmsprop')

为什么使用 add_loss 而不是将其指定为编译选项?就像是vae.compile(optimizer='rmsprop', loss=vae_loss)似乎不起作用并抛出以下错误:

ValueError: The model cannot be compiled because it has no loss to optimize.

这个函数和自定义损失函数有什么区别,我可以将其添加为 Model.fit() 的参数?

提前致谢!

P.S.:我知道 github 上有几个与此相关的问题,但其中大多数都是开放的且未评论。如果这个问题已经解决了,请分享链接!


Edit 1

我删除了向模型添加损失的行,并使用了编译函数的损失参数。现在看起来像这样:

# Build model
vae = Model(x, x_decoded_mean)

# Calculate custom loss
xent_loss = original_dim * metrics.binary_crossentropy(x, x_decoded_mean)
kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
vae_loss = K.mean(xent_loss + kl_loss)

# Compile
vae.compile(optimizer='rmsprop', loss=vae_loss)

这会引发类型错误:

TypeError: Using a 'tf.Tensor' as a Python 'bool' is not allowed. Use 'if t is not None:' instead of 'if t:' to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.

Edit 2

感谢@MarioZ 的努力,我找到了解决方法。

# Build model
vae = Model(x, x_decoded_mean)

# Calculate custom loss in separate function
def vae_loss(x, x_decoded_mean):
    xent_loss = original_dim * metrics.binary_crossentropy(x, x_decoded_mean)
    kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    vae_loss = K.mean(xent_loss + kl_loss)
    return vae_loss

# Compile
vae.compile(optimizer='rmsprop', loss=vae_loss)

...

vae.fit(x_train, 
    x_train,        # <-- did not need this previously
    shuffle=True,
    epochs=epochs,
    batch_size=batch_size,
    validation_data=(x_test, x_test))     # <-- worked with (x_test, None) before

由于某些奇怪的原因,我必须在拟合模型时显式指定 y 和 y_test 。本来,我不需要这样做。生产的样品对我来说似乎很合理。

虽然我可以解决这个问题,但我仍然不知道这两种方法的区别和缺点是什么(除了需要不同的语法之外)。有人可以给我更多的见解吗?


我将尝试回答最初的问题:为什么model.add_loss()正在使用而不是指定自定义损失函数model.compile(loss=...).

Keras 中的所有损失函数始终采用两个参数y_true and y_pred。看看Keras中可用的各种标准损失函数的定义,它们都有这两个参数。它们是“目标”(许多教科书中的 Y 变量)和模型的实际输出。大多数标准损失函数都可以写成这两个张量的表达式。但一些更复杂的损失不能这样写。对于您的 VAE 示例,情况就是如此,因为损失函数还取决于附加张量,即z_log_var and z_mean,这不适用于损失函数。使用model.add_loss()没有这样的限制,并允许您编写依赖于许多其他张量的更复杂的损失,但它的不便之处在于更加依赖于模型,而标准损失函数适用于任何模型。

(注意:这里其他答案中提出的代码有些作弊,因为它们只是使用全局变量来潜入额外的所需依赖项。这使得损失函数不是数学意义上的真正函数。我认为这很重要不太干净的代码,我预计它更容易出错。)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Keras 中的 add_loss 函数的用途是什么? 的相关文章

随机推荐

  • 超时错误:400 StatusCode 错误:“要求失败:会话不活动。”

    我在用着Zeppelin v0 7 3笔记本运行Pyspark脚本 在一段中 我正在运行脚本来写入数据dataframe to a parquetBlob 文件夹中的文件 文件按国家 地区进行分区 数据帧的行数是99 452 829 当脚本
  • 词典顺序的定义? [复制]

    这个问题在这里已经有答案了 我目前正在阅读有关std next permutation http en cppreference com w cpp algorithm next permutation函数并遇到了术语 字典顺序 在特定的时
  • android 模拟器无法启动

    我正在尝试从 Android 虚拟设备管理器启动 Android 虚拟设备 Android 模拟器窗口打开 但屏幕仍显示在 Android 徽标上 并且没有进一步进展 在开始第二个 avd 之前它工作得很好 现在第一个 avd 和第二个 a
  • Spirit-Qi:如何编写非终结符解析器?

    我想写一个可以使用的解析器 作为 qi 扩展 通过my parser p1 p2 where p1 p2 是 qi 解析器表达式 其实我想实现一个best match解析器的工作方式类似于 qi 替代方案 但不选择第一个匹配规则 而是选择
  • 学习 jQuery 的 CSS 选择器

    我想学习 jQuery 在我看来 jQuery 只需选择你想要的元素 然后对其执行一些操作 但选择方式与CSS选择器很接近 而且我不熟悉CSS css选择器 因为我一直认为它相当不系统 我找不到任何规则 我对 CSS 选择器的了解如下 id
  • 在 Chrome 中,触发 $(document).ready() 时不会加载 资源。为什么?

    在 Firefox 和 IE 中 SVG SVG 文档在以下情况下检索 document ready 叫做 在 Chrome 中 getSVGDocument当以下情况时返回 null document ready 叫做 虽然似乎在大约 7
  • Htaccess 重写删除尾部斜杠

    Htaccess 以某种方式自动删除 url 末尾的所有尾部斜杠并只保留一个 例如http localhost api param1 http localhost api param1 变成http localhost api param1
  • 要安装多少个 wiki 实例?

    我被要求安装 Mediawiki 来保存公司内部网的文档 此外 我被要求安装several实例 每个贡献组一个 非技术用户获得一个 开发人员获得一个 管理人员获得一个 等等 我们的想法是为每个组提供单独的网络空间 有没有一种方法可以在一个实
  • 如何缓存从 Ajax 调用接收到的数据?

    我想缓存从服务器接收的数据 以便执行最少数量的 PHP MySQL 指令 我知道缓存选项是自动为 ajax 设置的 但是 每次调用 ajax 时我都会看到 MySQL 指令 即使 postdata 与之前的调用中的相同 我错过了什么吗 缓存
  • Code::Blocks 出现无效工具链错误

    Hello Debug uses an invalid compiler Probably the toolchain path within the compiler options is not setup correctly Skip
  • 等待句柄会释放线程获取的锁吗?

    当我有如下所示的代码时 我的问题是调用 signal WaitOne 的线程是否释放已获取的锁以供另一个线程获取锁 我认为这是一个微不足道的问题 但我尝试寻找类似的东西 却一无所获 如果有人可以阐明这一点并修改我的帖子 标题 使其更容易被将
  • 从 git 提交生成 PDF 日志

    我知道我可以使用 git log 以各种方式查看以前的提交 但我想知道这里是否有人可以推荐一些用于从 git 提交创建 PDF 或 HTML 日志的好工具 我希望能够生成类似于 Github 提交日志风格但具有不同信息的内容 如果人们对生成
  • 在数基之间转换数字

    我正在开发一个在数字基数之间进行转换的程序 例如八进制是 8 十进制是 10 字母A to Z可以被视为基数 26 我想将 A 转换为0 Z转换为25 AA 转换为27 BA 转换为53 在开始编码之前 我会在纸上进行编码 以便我了解整个过
  • 为什么 date() 不能正确地将 YYMMDDHHMM 转换为 MySQL 可接受的日期格式?

    我想要一个像这样的字符串 1511030830 YYMMDDHHMM 并创建一个 MySQL 时间戳 如下所示 2015 11 03 08 30 00 但是 当尝试这样做时 它将不起作用 string 1511030830 date dat
  • 以编程方式启动时 Appium 不会初始化驱动程序

    我正在使用 Java 和 Selenium 通过命令行初始化 Appium 以便在 Android chrome 浏览器上运行测试 然而 该过程运行无限时间 并且代码来自 DesiredCapabilities 该行没有被执行 代码 Pro
  • 列出给定类的所有内部类 - Python

    给定一个类 我如何列出它的所有inner课程 class Car some var var class Engine some other var var2 class Body another var var3 现在给出Car我希望能够列
  • Next JS在arcgis地图上的多个坐标上显示信息窗口

    下面是我的下一个 JS 代码 它显示了一个简单的 ArcGIS 地图 其中包含特定坐标上的点或标记 谁能告诉我如何在地图上显示点的弹出 信息窗口 例如我单击任意点 它将在其上打开相应的弹出窗口 import NavBar from comp
  • 从 Swift 初始化器调用方法

    假设我在 Swift 中有以下类 有明显的问题 class MyClass let myProperty String init super init self setupMyProperty func setupMyProperty my
  • 如何在 AWS Cognito 中编辑尝试更改密码的限制?

    我已经实现了更改密码功能 现在我想测试一下 但我面临着尝试的极限 我应该做什么来防止这个错误 已超出尝试次数限制 请稍后再试 我是 Cognito 团队的成员 这是不可配置的 我们确实有保护机制来防止用户滥用忘记密码的 API 这可能就是您
  • Keras 中的 add_loss 函数的用途是什么?

    目前 我偶然发现了变分自动编码器 并尝试使用 keras 让它们在 MNIST 上工作 我找到了一个教程github https github com keras team keras blob master examples variat