实现多类骰子损失函数

2023-12-23

我正在使用 UNet 进行多类分割。我对模型的输入是HxWxC我的输出是

outputs = layers.Conv2D(n_classes, (1, 1), activation='sigmoid')(decoder0)

Using SparseCategoricalCrossentropy我可以很好地训练网络。现在我还想尝试骰子系数作为损失函数。实施如下,

def dice_loss(y_true, y_pred, smooth=1e-6):
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.math.sigmoid(y_pred)

    numerator = 2 * tf.reduce_sum(y_true * y_pred) + smooth
    denominator = tf.reduce_sum(y_true + y_pred) + smooth

    return 1 - numerator / denominator

然而,我实际上得到的是越来越多的损失,而不是减少的损失。我检查了多个来源,但我找到的所有材料都使用骰子损失进行二元分类而不是多类分类。所以我的问题是实施有问题。


问题是你的骰子损失并没有解决你拥有的类的数量,而是假设二进制情况,所以它可能解释你的损失的增加。

您应该实现涵盖所有类别的广义骰子损失,并返回所有类别的值。

像下面这样:

def dice_coef_9cat(y_true, y_pred, smooth=1e-7):
    '''
    Dice coefficient for 10 categories. Ignores background pixel label 0
    Pass to model as metric during compile statement
    '''
    y_true_f = K.flatten(K.one_hot(K.cast(y_true, 'int32'), num_classes=10)[...,1:])
    y_pred_f = K.flatten(y_pred[...,1:])
    intersect = K.sum(y_true_f * y_pred_f, axis=-1)
    denom = K.sum(y_true_f + y_pred_f, axis=-1)
    return K.mean((2. * intersect / (denom + smooth)))

def dice_coef_9cat_loss(y_true, y_pred):
    '''
    Dice loss to minimize. Pass to model as loss during compile statement
    '''
    return 1 - dice_coef_9cat(y_true, y_pred)

此片段摘自https://github.com/keras-team/keras/issues/9395#issuecomment-370971561 https://github.com/keras-team/keras/issues/9395#issuecomment-370971561

这是针对 9 个类别的,您应该根据您拥有的类别数量进行调整。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

实现多类骰子损失函数 的相关文章

随机推荐

  • 在 Delphi 2010 中将字符串写入 TFileStream

    我有 Delphi 2007 代码 如下所示 procedure WriteString Stream TFileStream var SourceBuffer PChar s string begin StrPCopy SourceBuf
  • Laravel:Stripe 未提供 API 密钥。 (提示:使用 Stripe::setApiKey() 设置 API 密钥

    是的 我知道 那里有一个完全相同的问题 但 解决方案 没有得到批准或指定 所以 1 我通过 php Composer phar require stripe etc 安装了 stripe 库 v 3 0 并且安装正常 否则我实际上不会收到该
  • 多个水平容器上的延迟加载

    我正在使用延迟加载 jQuery 插件 http www appelsiini net projects lazyload http www appelsiini net projects lazyload 我的问题是 是否可以有多个滚动容
  • Python模拟对象实例化

    使用Python 2 7和模拟库 如何使用模拟测试某些修补对象是否已使用某些特定参数进行初始化 这里有一些示例代码和伪代码 单元测试 py import mock mock patch mylib SomeObject def test m
  • 在“应用”中使用数据框列名称作为图表标签

    我想创建一系列 x y 散点图 其中 y 始终是相同的变量 x 是我想要检查它们是否相关的变量 作为一个例子 让我们使用mtcars数据集 我对 R 比较陌生 但正在进步 下面的代码有效 列表图表包含所有图表 除了 X 轴显示为 x 我希望
  • 如何在 BigDecimal 上使用 >、=、< 等比较运算符

    我有一个域类unitPrice set as BigDecimal数据类型 现在我正在尝试创建一种方法来比较价格 但似乎我不能在其中使用比较运算符BigDecimal数据类型 我必须更改数据类型还是有其他方法 简而言之 firstBigDe
  • Gunicorn:没有名为“wsgi”的模块

    我有一个项目设置为使用 docker 一台一台机器运行 即 ubuntu 我一直运行良好 但最近我尝试在我的 Windows 笔记本电脑上运行它 并收到 ModuleNotFoundError 2018 01 05 20 31 46 000
  • 当库使用模板(泛型)时,是否可以使用 Rust 中的 C++ 库?

    当库 例如Boost http www boost org 使用模板 泛型 Yes 但也可能不是实际的 D 编程语言是极少数提供一定程度的 C 互操作性的语言之一 你可以阅读更多相关内容dlang https dlang org spec
  • Swift NSTimer 在后台运行

    我遇到了很多关于如何在堆栈或其他地方在后台处理 NSTimer 的问题 我已经尝试了所有实际上有意义的选项之一 当应用程序进入后台时停止计时器 NSNotificationCenter defaultCenter addObserver s
  • C++11:如何获取指针或迭代器指向的类型?

    更具体地说 假设我正在写template
  • C++ 读取字符并创建数组

    如何从文件中读取一行字符 首先 程序从文件中读取一个整数 该数字表示下一步要读入多少个字符 下一步读取字符并将它们存储在数组中 那么我如何创建 char 变量 以便我可以正确读取 Michael 的字符并将它们显示在数组中 file txt
  • 将 BuildKit 与 Docker 结合使用时,如何查看 RUN 命令的输出?

    构建 Docker 镜像时DOCKER BUILDKIT 1 有一个非常酷的进度指示器 但没有命令输出 如何查看命令输出来调试我的构建 你有没有尝试过 progress plain Example FROM alpine RUN ps au
  • 将 Android 应用缩放到不同的屏幕尺寸

    所以我正在努力将我的应用程序扩展到不同的屏幕尺寸 目前它针对 10 1 英寸屏幕进行了优化 但我正在努力让它在具有 7 英寸屏幕的 kindle fire 上运行 我只使用相对布局 到目前为止 我的背景可以完美缩放 但背景顶部的图像按钮无法
  • MySQLi 和 PDO 哪种方法更安全[关闭]

    就目前情况而言 这个问题不太适合我们的问答形式 我们希望答案得到事实 参考资料或专业知识的支持 但这个问题可能会引发辩论 争论 民意调查或扩展讨论 如果您觉得这个问题可以改进并可能重新开放 访问帮助中心 help reopen questi
  • 在 Windows 2012 R2 上运行时,如何让 MSI 返回正确的 VersionNT 值?

    当我在 Windows 2012 R2 计算机 RTM 内部版本 9600 上运行 MSI 时 VersionNT 属性设置为 602 而不是 603 如果 602 实际上是正确的操作系统版本 那么如何在安装时以编程方式区分 Windows
  • 如何让 TeamCity 使用 MSTest 运行测试?

    我正在尝试弄清楚如何让 TeamCity 运行我的 MSTest 我使用以下参数设置了构建步骤 MSTest exe 的路径 system MSTest 10 0 列出汇编文件 项目 Metadude Tests bin Debug Met
  • C++11 Lambda 表达式作为回调函数

    是否有任何 C GUI 工具包支持将回调函数定义为 C 11 lambda 表达式 我相信这是使用 C 至少与 C 相比 编写基于 GUI 的程序的独特优点 对于采用 lambda 表达式作为参数的函数 我应该使用什么类型签名以及它们如何支
  • 角度绑定到带有空格的方括号表示法属性

    是否可以使用访问属性的方括号表示法绑定到角度属性 例如 使用伪代码
  • MongoDB 索引:多个单字段与单个复合索引?

    我有一个地理空间 时间数据的集合 其中包含一些附加属性 我将在地图上显示它们 目前 该集合已包含数百万份文档 并且会随着时间的推移而不断增长 每个文档都有以下字段 位置 geojson 对象 日期 日期对象 缩放级别 int32 条目类型
  • 实现多类骰子损失函数

    我正在使用 UNet 进行多类分割 我对模型的输入是HxWxC我的输出是 outputs layers Conv2D n classes 1 1 activation sigmoid decoder0 Using SparseCategor