为 Keras 逐个元素编写自定义损失函数

2024-02-16

我是机器学习、Python 和 Tensorflow 的新手。我习惯用 C++ 或 C# 编写代码,很难使用 tf.backend。 我正在尝试为 LSTM 网络编写一个自定义损失函数,尝试预测时间序列的下一个元素是正数还是负数。我的代码与binary_crossentropy损失函数一起运行得很好。我现在想改进我的网络,使其具有一个损失函数,如果预测概率大于 0.5,则添加下一个时间序列元素的值;如果概率小于或等于 0.5,则减去它。 我尝试过这样的事情:

def customLossFunction(y_true, y_pred):
    temp = 0.0
    for i in range(0, len(y_true)):
        if(y_pred[i] > 0):
            temp += y_true[i]
        else:
            temp -= y_true[i]
    return temp

显然,尺寸是错误的,但由于我在调试时无法进入我的函数,因此很难在这里掌握尺寸。 您能告诉我是否可以使用逐元素函数吗?如果是,怎么办?如果没有,你能帮我解决 tf.backend 问题吗? 多谢


从 keras 后端函数中,您可以得到以下函数greater您可以使用:

import keras.backend as K

def customLossFunction(yTrue,yPred)

    greater = K.greater(yPred,0.5)
    greater = K.cast(greater,K.floatx()) #has zeros and ones
    multiply = (2*greater) - 1 #has -1 and 1

    modifiedTrue = multiply * yTrue

    #here, it's important to know which dimension you want to sum
    return K.sum(modifiedTrue, axis=?)

The axis应根据您想要求和的内容使用参数。

axis=0 -> batch or sample dimension (number of sequences)     
axis=1 -> time steps dimension (if you're using return_sequences = True until the end)     
axis=2 -> predictions for each step 

现在,如果您只有一个 2D 目标:

axis=0 -> batch or sample dimension (number of sequences)
axis=1 -> predictions for each sequence

如果您只是想对每个序列的所有内容求和,那么就不要放置 axis 参数。

关于此功能的重要说明:

因为它只包含来自的值yTrue,它不能反向传播来改变权重。这将导致“不支持任何值”错误或非常类似的错误。

虽然yPred(与模型权重相关的)在函数中使用,它仅用于获取 true x false 条件,该条件是不可微分的。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

为 Keras 逐个元素编写自定义损失函数 的相关文章

随机推荐

  • 用户窗体根据屏幕分辨率调整大小

    我有一个 Excel 用户表单 我想在打开时调整大小以适应屏幕分辨率 我通过得到高度和宽度Application Height and Application Width 通常使用这两个参数和以下代码 应该可以解决问题 Me Top App
  • 以编程方式最大化窗口并防止用户更改窗口状态

    如何以编程方式最大化窗口 以便窗口一旦打开就无法调整大小 达到最大化状态 例如 最大化 Internet Explorer 并查看它 我将 FormWindowState 属性设置为 this WindowState FormWindowS
  • 检测视图中的任何触摸(iPhone SDK)

    我目前正在使用 void touchesBegan NSSet touches withEvent UIEvent event void touchesEnded NSSet touches withEvent UIEvent event
  • 在 Electron 中使用量角器

    我正在尝试为我运行的应用程序设置单元测试和 e2e 测试Electron http electron atom io using 量角器 https angular github io protractor 我参考了很多不同的帖子 this
  • Rails 渲染路线路径

    我对 Rails 还很陌生 很难理解 Rails 中路径系统的工作原理 在我的routes rb中 我创建了一个用于注册的别名 match signup gt user new resource user controller gt use
  • 数据路径 '''' 不应具有附加属性 (es5BrowserSupport)

    尝试在 Angular 中开始 在 CLI 中创建项目后 我尝试使用两者打开项目ng serve o and npm start但我收到以下错误 Schema validation failed with the following err
  • 线程与并行,它们有何不同?

    线程和并行有什么区别 哪一个比另一个有优势 Daniel Moth 我的前同事 线程 并发与并行 http www danielmoth com Blog 2008 11 threadingconcurrency vs parallelis
  • 为什么 Firefox 忽略基于范围查询的缓存控制?

    Web 服务器能够将媒体 本例中为音频 传输到浏览器 浏览器使用 HTML5 控件来播放媒体 然而 我发现 Firefox 正在缓存媒体 尽管我 相信我 明确告诉它不要这样做 我有预感 它与 206 部分内容响应有关 因为带有完整 200
  • 在分页期间获取SQL Server中记录总数的有效方法

    当查询 sql server 中的表时 我试图仅获取当前页的记录 但是 我需要为特定查询返回的记录总数来计算页数 如何在不编写另一个查询来计算记录的情况下有效地执行此操作 WITH allentities AS SELECT Row num
  • 错误 CS0106:修饰符“private”对于此项无效 Unity 中的 C# 错误

    我不断收到此错误 CS0106 修饰符 私有 对此项目无效 并且需要一些帮助 我正在尝试为我的游戏制作一个随机对象生成器 但由于我仍然是新手编码器 我似乎不知道如何解决这个问题 你能帮忙的话 我会很高兴 这是我使用的代码 using Sys
  • 谷歌地图可以根据小时分钟秒绘制点吗

    我正在尝试绘制以时分秒秒格式提供给我的 GPS 数据 GLatLng 会采用这种形式吗 还是我需要先转换它 很难在互联网上找到与此相关的任何内容 如果可以采用这种格式 我们将不胜感激 据我所知它不接受这种格式 但转换它真的很容易 只需计算一
  • Python 中的最佳 ETL 包

    我有两个用例 从 Oracle PostgreSQL Redshift S3 CSV 提取 转换并加载到我自己的 Redshift 集群 安排作业每天 每周运行 INSERT TABLE 或 INSERT NONE 选项更好 我目前正在使用
  • 如何在android中为不同类型的标记指定onMarkerclick()

    在谷歌地图上 我为不同的目的放置不同颜色的标记 在这里我想要每个标记的 onMarkerclick 具有不同的功能 例如 所有绿色标记 如何为此创造条件 这是我创建一组的代码标记数 Override public void onMapLon
  • 如何为 UIFontMetrics 指定最小 UIContentSizeCategory?

    我有一种基于动态类型创建自动缩放字体的方法 如下所示 extension UIFont public static func getAutoScalingFont fontName String textStyle UIFont TextS
  • 使用 conda env 的 apache-airflow systemd 文件

    我正在尝试奔跑apache airflow在 Ubuntu 16 04 文件上 使用 systemd 我大致跟着本教程 https github com hgrif airflow tutorial并安装 设置以下内容 Miniconda
  • Android 中打开 pdf 文件的限制

    我正在尝试从 Android 应用程序中打开一些 pdf 文件 我正在使用 Intent 来执行此操作 Intent intent new Intent intent setDataAndType Uri parse url applica
  • 是否可以声明Supplier需要抛出异常?

    所以我尝试重构以下代码 Returns the duration from the config file return The duration private Duration durationFromConfig try return
  • Webpack nodejs fs.readFile 不是函数

    我有一个 webpack 配置 例如 var path require path module exports entry index js output path path join dirname static filename bun
  • 如何发出引发 RecordNotFound 的“查找”或“位置”

    当我调用带有 id 的查找时 它会变成目标查找 并会抛出错误 RecordNotFound Foo Bar find 123 RecordNotFound if no Bar with id 123 exists 但是当我有条件地调用它时
  • 为 Keras 逐个元素编写自定义损失函数

    我是机器学习 Python 和 Tensorflow 的新手 我习惯用 C 或 C 编写代码 很难使用 tf backend 我正在尝试为 LSTM 网络编写一个自定义损失函数 尝试预测时间序列的下一个元素是正数还是负数 我的代码与bina