如何在 Keras 中定义自定义精度以忽略具有特定金色标签的样本?

2024-03-23

我想在 Keras 中编写一个自定义指标(我正在使用张量流后端),相当于categorical_accuracy,但是具有特定金色标签的样本的输出(在我的例子中是 0,来自 y_true)必须被忽略。例如,如果我的输出是:

预测 1 - 金 0

预测 1 - 金 1

准确度将为 1,因为必须忽略带有金色标签 0 的样本。也就是说,我编写的函数(并且没有给出预期结果)是:

def my_accuracy(y_true, y_pred):

    mask = K.any(K.not_equal(K.argmax(y_true, axis=-1), 0), axis=-1, keepdims=True)

    masked_y_true = y_true*K.cast(mask, K.dtype(y_true))
    masked_y_pred = y_pred*K.cast(mask, K.dtype(y_pred))

    return keras.metrics.categorical_accuracy(masked_y_true, masked_y_pred)`

任何帮助表示赞赏,谢谢!


你可以尝试这个方法:

def ignore_accuracy_of_class(class_to_ignore=0):
    def ignore_acc(y_true, y_pred):
        y_true_class = K.argmax(y_true, axis=-1)
        y_pred_class = K.argmax(y_pred, axis=-1)

        ignore_mask = K.cast(K.not_equal(y_pred_class, class_to_ignore), 'int32')
        matches = K.cast(K.equal(y_true_class, y_pred_class), 'int32') * ignore_mask
        accuracy = K.sum(matches) / K.maximum(K.sum(ignore_mask), 1)
        return accuracy

    return ignore_acc
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在 Keras 中定义自定义精度以忽略具有特定金色标签的样本? 的相关文章

  • 如何解释和转换 Keras 分类器的预测值?

    我正在训练我的 Keras 模型来预测 使用提供的数据参数 它是否会射击 并且它将以 0 表示否 1 表示是的方式表示 然而 当我尝试预测它时 我得到的是浮点值 我尝试使用与训练数据完全相同的数据来获取 1 但它不起作用 我使用下面的数据尝
  • 如何在网络工作者中运行handpose tfjs模型

    我想使用网络摄像头获取帧并运行张量流模型 handpose 来估计手部可见度 众所周知 手势模型有点慢 所以我尝试将估计转移到网络工作人员 问题是HTMLVideoElement object could not be cloned 我需要
  • 如何在 Tensorflow 上测试自己的图像到 Cifar-10 教程?

    我训练了 Tensorflow Cifar10 模型 我想为其提供自己的单个图像 32 32 jpg png 我想将标签和每个标签的概率视为输出 但我对此遇到了一些麻烦 搜索堆栈溢出后 我发现了一些帖子this https stackove
  • Tensorflow 中多维时间序列预测中的向量表示

    我有一个大型数据集 约 3000 万个数据点 具有 5 个特征 我已使用 K 均值将其减少到 200 000 个集群 数据是大约 150 000 个时间步长的时间序列 我想要训练模型的数据是每个时间步上特定簇的存在 预测模型的目的是生成一个
  • 如何在运行 Tensorflow 推理会话之前批处理多个视频帧

    我做了一个项目 基本上使用谷歌对象检测 API 和张量流 我所做的就是使用预先训练的模型进行推理 这意味着实时对象检测 其中输入是网络摄像头的视频流或使用 OpenCV 的类似内容 现在我得到了相当不错的性能结果 但我想进一步提高 FPS
  • 在Tensorflow中,sampled_softmax_loss和softmax_cross_entropy_with_logits有什么区别

    在张量流中 有一些方法称为softmax cross entropy with logits https www tensorflow org versions master api docs python tf nn softmax cr
  • 如何用tensorflow计算AUC?

    我已经使用 Tensorflow 构建了一个二元分类器 现在我想使用 AUC 和准确性来评估分类器 就准确性而言 我可以轻松地这样做 X tf placeholder float None n input y tf placeholder
  • 如何将 .pb 文件转换为 .h5。 (张量流模型到keras)

    我已经使用重新训练了我的模型tensorflow现在想使用keras以避免会话内容 我怎样才能转换 pb文件至 h5 import tensorflow as tf from tensorflow keras models import s
  • BERT - 池化输出与序列输出的第一个向量不同

    我在 Tensorflow 中使用 BERT 有一个细节我不太明白 根据文档 https tfhub dev google bert uncased L 12 H 768 A 12 1 https tfhub dev google bert
  • 计算复合损失函数各部分的梯度范数

    假设我有以下损失函数 loss a tf reduce mean my loss fn model output targets loss b tf reduce mean my other loss fn model output tar
  • 如何使用 Keras 将图像文件夹转换为 X 和 Y 批次?

    假设我有一个图像文件夹 例如 PetData Dog images Cat images 我如何将其转换为 x train y train x test y test 格式 我看到这种格式广泛用于 MNIST 数据集 如下所示 mnist
  • Keras模型拟合多项式

    我从四次多项式生成了一些数据 并希望在 Keras 中创建一个回归模型来拟合该多项式 问题是拟合后的预测似乎基本上是线性的 由于这是我第一次使用神经网络 我认为我犯了一个非常微不足道且愚蠢的错误 这是我的代码 model Sequentia
  • Keras,训练模型后如何预测?

    我正在使用 reuters example 数据集 它运行良好 我的模型已经过训练 我阅读了有关如何保存模型的信息 以便稍后加载它以再次使用 但如何使用这个保存的模型来预测新文本呢 我用吗models predict 我必须以特殊方式准备这
  • 无法构建具有 int 输入的 Keras 层

    我有一个复杂的 keras 模型 其中一层是自定义预训练层 需要 int32 作为输入 该模型作为继承自 Model 的类实现 其实现如下 class MyModel tf keras models Model def init self
  • tensorflow SavedModel - 如何迭代保存

    我正在采用新的SavedModel据我所知 API 是 未来 应该优先于tf train Saver 我想要实现的目标是每次保存一个模型N批次数 我想最多保留 20 个已保存的模型 显然我可以自己监控这一点 但如果tf train Save
  • 分布式张量流中的并行进程

    我有带有训练参数的张量流神经网络 它是代理的 策略 网络正在核心程序的主张量流会话的训练循环中进行更新 在每个训练周期结束时 我需要将该网络传递给几个并行进程 工作人员 这些进程将使用它来从代理策略与环境的交互中收集样本 我需要并行执行 因
  • AttributeError:该层从未被调用,因此没有定义的输入形状

    我尝试通过创建三个类在 TensorFlow 2 0 中构建自动编码器 Encoder Decoder 和 AutoEncoder 由于我不想手动设置输入形状 因此我尝试从编码器的 input shape 推断解码器的输出形状 import
  • softmax_cross_entropy_with_logits和loss.log_loss有什么区别?

    之间的主要区别是什么tf nn softmax cross entropy with logits and tf losses log loss 两种方法都接受 1 hot 标签和 logits 来计算分类任务的交叉熵损失 这些方法在理论上
  • 没有名为“_pywrap_tensorflow_internal”的模块

    在尝试验证tensorflow gpu的安装时 在尝试执行 import tensorflow as tf 时出现ImportError 我在 Windows 7 上使用 Quadro K620 Tensorflow 是使用 pip 安装的
  • 如何仅从源代码构建 TensorFlow lite 而不是所有 TensorFlow?

    我正在尝试使用 Edgetpu USB 加速器与 Intel ATOM 单板计算机和 C API 进行实时推理 Edgetpu 的 C API 基于 TensorFlow lite C API 我需要包含来自tensorflow lite目

随机推荐

  • OpenJPA 2.1.1 - 找不到元素“persistence”的声明

    我刚刚下载了http www apache org dyn closer cgi openejb 4 0 0 beta 1 apache tomee 1 0 0 beta 1 webprofile zip http www apache o
  • 如何发出返回引用的动态方法?

    我正在浏览 ref 返回的来龙去脉 并且在发出由 ref 返回的动态方法时遇到问题 手工制作的 lambda 表达式和现有方法按预期工作 class Widget public int Length delegate ref int Wid
  • 如何键入组织捕获的动态文件条目

    我试图弄清楚是否有某种方法可以创建动态文件名以在 emacs org mode 中捕获 z test entry file headline A date specific headline Notes prompt 是否有一些简单的方法将
  • 如何进入android studio中的文件资源管理器

    好吧 我不知道如何进入 android studio 中的文件资源管理器 我已经尝试搜索堆栈溢出 并发现了我所问的相同问题 但那里的解决方案不起作用 那么有人可以通过屏幕截图告诉我如何进入文件资源管理器吗 对于 Android Studio
  • 如何检查字符串中是否包含特定单词?

    a how are you if strpos a are false echo true 在 PHP 中 我们可以使用上面的代码来检查字符串是否包含特定单词 但是如何在 JavaScript jQuery 中执行相同的功能 你可以为此使用
  • 如何修复 flutter 上的“simple_permissions”错误?

    当我在 flutter 项目上运行包含 simple permissions 的代码时出现错误 Initializing gradle Resolving dependencies Running Gradle task assembleD
  • 散景中的多线悬停

    正如在这个问题中 多线散景和 HoverTool https stackoverflow com questions 32975709 bokeh multi line and hovertool 我发现悬停工具没有针对多线图实现 这是一个
  • 从 UIWebView 创建 PDF 文件

    void createPDFfromUIView UIView aView saveToDocumentsWithFileName NSString aFilename Creates a mutable data object for u
  • 有没有办法在 Objective-C 中使用 Swift 结构而不将它们作为类?

    我有几个简单的structs在 swift 文件中用 swift 编写 这些structs非常简单 几乎只包含字符串 struct Letter struct A static let aSome String descASome stat
  • 汇总数据框忽略重复

    我有一个数据框 其中一列中有重复的条目 我想根据该专栏总结其他专栏 我希望摘要在进行摘要时考虑每个唯一条目而不是总数 例如 在下面的数据框示例中 如果我想回答以下问题受访者中有多少人是年轻人 中年人和老年人 RefID 1 1 在总结 ag
  • 如何解决错误:预期标识符或“(”

    我正在编程的东西有问题 我一遍又一遍地收到这个错误 jharvard appliance Dropbox pset1 make mario clang ggdb3 O0 std c99 Wall Werror mario c lcs50 l
  • (obj == null) vs (null == obj)?

    我的老板说我应该使用null obj 因为它比obj null 但他不记得为什么要这样做 有什么理由使用null obj 我以某种方式感觉到了 相反 经过谷歌搜索后 我唯一发现的是 在 C 中 它可以防止您意外地在条件结构中键入 obj n
  • iPhone文档文件夹库/缓存安全问题

    我开发了一款iOS应用程序 出于安全原因 我将所有音频 视频文件下载到下的 Documents 文件夹中库 缓存 对于使用 iTunes 最终用户无法备份此视频 但有些外部软件很容易打开库 缓存并从此文件夹下载所有文件 我的问题是如何保护此
  • 如何在IntelliJ 2021.2.2中使用Lombok插件?

    我是从 C 开始接触 Java 的 我一直在努力熟悉这门语言 我正在尝试使用 IntelliJ IDEA 的 lombok 插件 但它似乎根本不适合我 这是我的IDEA无法识别 value 注释的屏幕截图 https i stack img
  • 您能否将多个不同的值类型分配给重复的 Protobuf 消息中的一个字段?

    我正在尝试对客户端进行逆向工程 该客户端将音频文件上传到服务器 然后在单独的请求中上传文件的元数据 元数据在 Protobuf 中序列化 并且使用相当简单且可读的结构 这是之后的样子protoc decode raw 1 1 title 2
  • 对 Django 模板中的相关项目进行排序

    是否可以对 DJango 模板中的一组相关项目进行排序 即 这段代码 为了清楚起见 省略了 HTML 标签 for event in eventsCollection event location for attendee in event
  • laravel 一个帐户下有多个电子邮件地址

    我的 Laravel 应用程序要求用户可以拥有多个可用于登录的电子邮件地址 我的问题是 如何允许用户在一个帐户下拥有多个电子邮件地址 我必须记住 每封电子邮件只能由一个用户使用 我的想法是为电子邮件创建一个单独的表 其中包含用户 ID 我仍
  • delphi 7 中的 utf8 解码

    我需要使用 delphi 7 将字符串从 utf8 转换为宽字符串 谁能告诉我为什么下面的代码在delphi 7中不起作用 Utf8Decode 函数的参数只是一个示例 var ws WideString begin ws Utf8Deco
  • C# 如何杀死阻塞的线程?

    我有一个线程 void threadCode object o doStuffHere o Blocking call Sometimes hangs 我这样称呼它 Thread t new Thread new ThreadStart d
  • 如何在 Keras 中定义自定义精度以忽略具有特定金色标签的样本?

    我想在 Keras 中编写一个自定义指标 我正在使用张量流后端 相当于categorical accuracy 但是具有特定金色标签的样本的输出 在我的例子中是 0 来自 y true 必须被忽略 例如 如果我的输出是 预测 1 金 0 预