Keras 自定义损失函数:访问当前输入模式

2024-04-30

在 Keras(带有 Tensorflow 后端)中,当前输入模式可用于我的自定义损失函数吗?

当前输入模式被定义为用于产生预测的输入向量。例如,请考虑以下情况:X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42, shuffle=False)。那么当前输入模式是与 y_train 关联的当前 X_train 向量(在损失函数中称为 y_true)。

在设计自定义损失函数时,我打算优化/最小化需要访问当前输入模式的值,而不仅仅是当前的预测。

我已经浏览过了https://github.com/fchollet/keras/blob/master/keras/losses.py https://github.com/fchollet/keras/blob/master/keras/losses.py

我也看过《成本函数不仅仅是 y_pred、y_true? https://github.com/fchollet/keras/issues/7379"

我也熟悉以前的示例来生成定制的损失函数:

import keras.backend as K

def customLoss(y_true,y_pred):
    return K.sum(K.log(y_true) - K.log(y_pred))

想必(y_true,y_pred)在别处定义。我浏览了源代码但没有成功,我想知道我是否需要自己定义当前的输入模式,或者我的损失函数是否已经可以访问它。


您可以将损失函数包装为内部函数,并将输入张量传递给它(就像向损失函数传递附加参数时通常所做的那样)。

def custom_loss_wrapper(input_tensor):
    def custom_loss(y_true, y_pred):
        return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
    return custom_loss

input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

你可以验证一下input_tensor和损失值(主要是K.mean(input_tensor)部分)会随着不同而改变X被传递给模型。

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y)  # => 1.1974642

X *= 1000
model.test_on_batch(X, y)  # => 511.15466
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Keras 自定义损失函数:访问当前输入模式 的相关文章

随机推荐

  • 获取Spring bean的真实Class对象

    我正在使用 Spring 注入 Bean 我正在使用一些注释来注释 bean 方法 Security TransactionManagement ExceptionHanling Logging 问题是 我想创建 JUnit 测试来检查我是
  • 无法在 Zsh 中找到 Bash 的替代命令

    我将最新的 git completion bash 放入我的 zshrc 中 然后我得到 Users Masi bin shells git git completion bash 2116 command not found comple
  • 如何使用 Linq 查询 Azure 存储表?

    我不确定具体在哪里 但我在某个地方有错误的想法 我首先尝试使用 linq 查询 azure 存储表 但我无法弄清楚它是如何完成的 通过查看各种来源 我得到以下信息 List
  • 出现错误 /usr/bin/env: 节点:权限被拒绝

    我已经在我的服务器 Centos 上完成了 ODOO v9 安装 一切都已安装成功 登录页面也可以正常工作 但登录后我收到一个包含以下错误的页面 usr bin env node Permission Denied 我尝试更改权限 但我的问
  • 法线在 openGL 中表现得很奇怪

    我一直在为 openGl 编写一个 obj 加载器 几何体加载得很好 但法线总是混乱的 我尝试在两个不同的程序中导出模型 但似乎没有任何效果 据我所知 这就是将法线放入 GL TRIANGLES 的方法 glNormal3fv norm1
  • 在 C# 中通过 Sharpsvn 使用 client.status

    我想使用状态方法 但我不明白它是如何工作的 有人可以给我看一个使用示例吗 EventHandler lt SvnStatusEventArgs gt statusHandler new EventHandler
  • Rails 验证:将输入限制为特定值

    我正在寻找 Rails Way 来编写验证 将可接受的输入值限制为预定列表 就我而言 我只想接受值 5 2 2 5 和nil 然而 我认为这最好作为一个一般性问题 如何在 Rails 模型中预定义可接受的条目值列表 Thanks valid
  • WSO2 ESB 中的跟踪日志文件

    在 WSO2 ESB 中 它到底显示 wso2 esb trace log 文件什么 什么时候有用 并且 与 WSO2 ESB 中的其他典型日志文件有何不同 例如 使用 wso2 esb service log 或 wso2 esb err
  • android开发打开文件txt并返回内容[关闭]

    很难说出这里问的是什么 这个问题是含糊的 模糊的 不完整的 过于宽泛的或修辞性的 无法以目前的形式得到合理的回答 如需帮助澄清此问题以便重新打开 访问帮助中心 help reopen questions 在互联网上搜索并找不到有效的代码 如
  • 片段重新进入过渡不起作用。需要帮助澄清各种片段转换

    我正在 RecyclerView 中的项目之间实现片段过渡动画 以及显示单击项目的详细信息的片段 换句话说 相对常见的 单击列表中的一张卡片 它会展开为详细视图 而列表的其余部分则消失 之类的事情 从 RecyclerView 项目到详细视
  • Python-错误:无法打开.png文件[重复]

    这个问题在这里已经有答案了 不确定我做错了什么 我正在遵循有关如何使用 Python 和 PyGame 制作游戏的教程 但收到错误 pygame error Couldn t open resources images dude png 我
  • 如何解释R中SVM的预测结果?

    我是 R 新手 我正在使用e1071R 中的 SVM 分类包 我使用了以下代码 data lt loadNumerical model lt svm data ncol data data ncol data gamma 10 print
  • 如何区分类实现中两个协议的相同方法名称?

    我有两个协议 protocol P1 void printP1 void printCommon end protocol P2 void printP2 void printCommon end 现在 我在一个类中实现这两个协议 inte
  • 让 csv.Sniffer 使用带引号的值

    我正在尝试使用python 的 CSV 嗅探工具 https docs python org 3 library csv html csv Sniffer正如许多 StackOverflow 答案中所建议的那样 猜测给定的 CSV 文件是否
  • Apache Ignite - 选择查询返回 0 条记录,但数据存在于缓存中

    我们使用 Apache Ignite 2 9 0 它是一个具有 Zookeeper 发现功能的 5 节点集群 我们通过从 Intellij 执行 DDL 语句在 Ignite 中创建表 然后我们可以通过从 Intellij 本身运行选择查询
  • 在 AppEngine 上的 iText 中添加新字体时出现 NoClassDefFoundError

    我有一个 appengine java 项目 其中包括有时创建 pdf pdf 文档有我试图包含的特殊字体 BaseFont bf BaseFont createFont resources AlexBrush Regular ttf Ba
  • SQLAlchemy 的数据类默认不填充 postgres 数据库

    我在用dataclasses与 SQLAlchemy 经典映射范例相结合 当我定义一个dataclass与默认值相结合int and strSQLAlchemy 不会填充字段int and strs 但它确实填充了List and date
  • 如何实施刷新令牌轮换?

    如果我正确理解了刷新令牌轮换 这意味着每次我们请求新的访问令牌时 我们也会获得一个新的刷新令牌 如果多次使用刷新令牌 我们会使某个用户之前使用的所有刷新令牌失效 并且用户必须再次执行身份验证过程 这是否意味着我们需要将所有刷新令牌 所有旧的
  • python - 创建具有多种颜色的图像并添加文本

    我正在尝试用 python 中的一些文本创建图像 例如 import PIL from PIL import ImageFont from PIL import Image from PIL import ImageDraw font Im
  • Keras 自定义损失函数:访问当前输入模式

    在 Keras 带有 Tensorflow 后端 中 当前输入模式可用于我的自定义损失函数吗 当前输入模式被定义为用于产生预测的输入向量 例如 请考虑以下情况 X train X test y train y test train test