解决类别不平衡:扩大对损失和 SGD 的贡献

2023-11-29

(已添加对此问题的更新。)

我是比利时根特大学的研究生;我的研究是关于深度卷积神经网络的情感识别。我正在使用Caffe实施 CNN 的框架。

最近我遇到了一个关于班级不平衡的问题。我正在使用大约 9216 个训练样本。 5% 被标记为阳性 (1),其余样本被标记为阴性 (0)。

我正在使用S形交叉熵损失层来计算损失。训练时,即使经过几个时期,损失也会减少,并且准确率极高。这是由于不平衡造成的:网络总是预测负数 (0)。(准确率和召回率均为零,支持这一说法)

为了解决这个问题,我想根据预测与真实的组合调整对损失的贡献(严厉惩罚假阴性)。我的导师/教练也建议我反向传播时使用比例因子通过随机梯度下降(sgd):该因子将与批次中的不平衡性相关。仅包含负样本的批次根本不会更新权重。

我只向 Caffe 添加了一个定制层:报告其他指标,例如精确度和召回率。我对 Caffe 代码的经验有限,但我在编写 C++ 代码方面拥有丰富的专业知识。


谁能帮助我或为我指出如何调整的正确方向S形交叉熵损失 and Sigmoid层以适应以下更改:

  1. 根据预测-真值组合(真阳性、假阳性、真阴性、假阴性)调整样本对总损失的贡献。
  2. 根据批次中的不平衡性(负数与正数)缩放由随机梯度下降执行的权重更新。

提前致谢!


Update

我已经合并了信息增益损失层正如建议的Shai。我还添加了另一个构建信息增益矩阵的自定义层H根据当前批次的不平衡情况。

目前,矩阵配置如下:

H(i, j) = 0          if i != j
H(i, j) = 1 - f(i)   if i == j (with f(i) = the frequency of class i in the batch)

我计划将来尝试不同的矩阵配置。

我已经在 10:1 不平衡的情况下对此进行了测试。结果表明网络现在正在学习有用的东西:(30个epoch后的结果)

  • 准确度约为。 〜70%(低于〜97%);
  • 精度约为。 ~20%(从 0% 上升);
  • 召回率约为~60%(从 0% 上升)。

这些数字在 20 个 epoch 左右达到,此后没有显着变化。

!!上述结果只是概念证明,它们是通过在 10:1 不平衡数据集上训练简单网络获得的。 !!


你为什么不使用信息增益损失层来补偿训练集的不平衡?

Infogain 损失是使用权重矩阵定义的H(在你的情况下是2×2)其条目的含义是

[cost of predicting 1 when gt is 0,    cost of predicting 0 when gt is 0
 cost of predicting 1 when gt is 1,    cost of predicting 0 when gt is 1]

因此,您可以设置以下条目H反映预测 0 或 1 时的误差之间的差异。

您可以找到如何定义矩阵H对于咖啡在这个线程.

关于样本权重,您可能会发现这个帖子有趣的是:它展示了如何修改带损失的 Softmax层考虑样本权重。


最近,提出了对交叉熵损失的修改林聪怡、Priya Goyal、Ross Girshick、何凯明、Piotr Dollár 用于密集物体检测的焦点损失, (ICCV 2017).
焦点损失背后的想法是根据预测该示例的相对难度(而不是基于班级规模等)为每个示例分配不同的权重。从我尝试这种损失的短暂时间来看,它感觉优于"InfogainLoss"具有班级规模的权重。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

解决类别不平衡:扩大对损失和 SGD 的贡献 的相关文章

  • 如何将异常对象序列化为 xml 字符串

    我想要类似的东西 try code here catch Exception ex stringXML Exception toXML 这样 stringXML 的值就是
  • 无法在表适配器配置属性中找到对象“Web.config”的连接“MyConnName”

    I want to change the query in table adapter but it s not opening throwing an error Configure table Adapter Failed in pro
  • Visual Studio 2015 C# 找不到参考

    我在使用 Visual Studio 2015 和 C 时遇到了问题 在同一解决方案中添加对其他项目的引用时 Visual Studio 找不到所有类 例如 我创建了一个单元测试项目 我添加了对我创建的通信项目的引用 库中有 10 个类 但
  • 使用 CMake 对 SDL 的未定义引用

    我正在使用 SDL v1 2 15 7 和 CMake 3 2 1 开发一个项目 在 h 文件中我添加了 include
  • 对无符号 8 位整数进行左移操作 [重复]

    这个问题在这里已经有答案了 我试图理解 C C 中的移位运算符 但它们给我带来了困难 我有一个无符号 8 位整数 初始化为一个值 例如 1 uint8 t x 1 根据我的理解 它在内存中的表示方式如下 0 0 0 0 0 0 0 1 现在
  • 为什么我在这段代码中不断得到两个相同的随机值? [复制]

    这个问题在这里已经有答案了 可能的重复 为什么我的随机数生成器在 C 中不是随机的 https stackoverflow com questions 932520 why does it appear that my random num
  • 二维数组的列求和

    我有一个IEnumerable
  • 本地主机和 request.Url.Authority

    我的应用程序通过 URL 中的公司标识符分隔用户 company1 app com company2 app com 我正在本地 PC 上进行测试 请求如下 company1 localhost com 但是 我的 request Url
  • Docker 不遵循构建目录中的符号链接

    我正在对一个应用程序进行 Docker 化 其中涉及通过 Clang 将二进制文件与其他 C 文件链接 我们维护二进制文件的符号链接版本 因为它们在整个代码库中使用 我的 Docker 构建目录包含整个代码库 包括源文件以及这些源文件的符号
  • 在 OpenGL 中使用不同的着色器程序?

    我必须在 OpenGL 中针对不同的对象使用两个不同的着色器程序 我发现我必须使用glUseProgram 在不同的着色器程序之间切换 但对此没有太多信息 鉴于我有两个用于不同对象的不同着色器程序 如何为每个着色器程序生成和绑定 VAO 和
  • Ajax 函数在重定向后不保存滚动位置

    正如标题所述 我编写了一个 ajax 函数 该函数应该滚动到用户在重定向之前所在的位置 我写了一个alert对于测试场景 它确实触发了 但滚动不断回到顶部 我在这里做错了什么 JavaScript ajax type GET url Adm
  • DataContractJsonSerializer 包含元素类型子类型的通用列表

    我要使用DataContractJsonSerializer用于 JSON 序列化 反序列化 我在 JSON 数组中有两种对象类型 并希望将它们都反序列化为相应的对象类型 具有以下类定义 DataContract public class
  • 不兼容的指针到字符转换

    我正在编写一个程序 将卡片值写入 52 个点字符的多维数组中 该程序是一个测试数组 稍后我将其作为函数写入主程序中 在程序中 我通过以下方式初始化 for 循环计数0通过51 我用一个switch语句调制13将卡牌值分配给数组点 但是 我收
  • 您的 C++ 程序中是否仍然存在内存分配失败问题 [关闭]

    Closed 这个问题是基于意见的 help closed questions 目前不接受答案 我正在为公司写一些指导方针 我需要回答一些棘手的问题 这一项是相当困难的 解决方案可以是 根本不跟踪 确保使用 new 分配对象 这会在分配失败
  • 为什么 char 数组需要 strcpy 而 char star 不需要 - 在 C 中使用结构

    我对这段代码有一个误解 typedef struct EXP int x char name char lastName 40 XMP main XMP a a name eaaa a lastName strcpy a lastName
  • 如何正确地将十六进制转义添加到字符串文字中?

    当你有C语言的字符串时 你可以在里面直接添加十六进制代码 char str abcde a b c d e 0x00 char str2 abc x12 x34 a b c 0x12 0x34 0x00 这两个示例在内存中都有 6 个字节
  • C# 中的快速字符串解析

    在 C 中解析字符串最快的方法是什么 目前我只是使用字符串索引 string index 并且代码运行合理 但我忍不住认为索引访问器所做的连续范围检查必须添加一些东西 所以 我想知道我应该考虑哪些技术来增强它 这些是我最初的想法 问题 使用
  • 使用 Crypto++ 和 .NET 的 CFB 模式下的 TripleDES

    我正在尝试使用 TripleDES 使用 C 应用程序获得相同的结果 该应用程序具有Crypto https www cryptopp com 和 NET应用程序使用三重DESCryptoServiceProvider https msdn
  • 频繁插入已排序的集合

    我已经对集合 列表 进行了排序 并且我需要始终保持其排序 我目前在我的集合上使用 List BinarySearch 然后在正确的位置插入元素 我也尝试过在每次插入后对列表进行排序 但性能不可接受 有没有一种解决方案可以提供更好的性能 也许
  • 清理 TPL 中的 CallContext

    根据我使用的是基于 async await 的代码还是基于 TPL 的代码 我在逻辑清理方面得到了两种不同的行为CallContext 我可以设置和清除逻辑CallContext如果我使用以下异步 等待代码 正如我所期望的 class Pr

随机推荐

  • Laravel - 使用 whereHas 获取最后一行

    我正在尝试获取上次用户活动的时间 created at 我有模型User and UserActivity 我想获取最后一个用户活动并检查该用户的最后一个活动是否是 3 天发送通知 User php
  • 忍者。对内部设置属性的奇怪拦截

    域对象 目标对象 cs public class TargetObject public virtual ChildTargetObject ChildTargetObject get return ChildTargetObjectInn
  • python: 为什么使用子进程调用 echo 会返回 WindowsError 2?

    在我的程序中 我有一个函数 runScript def runScript subprocess call echo hello 我在 Python 文档中看到过很多类似的例子 所以我认为这可行 但是 当我在程序中调用此函数时 它返回 Wi
  • 如何用C++实现“虚拟模板功能”

    首先 我已经阅读过并且现在知道虚拟模板成员函数在 C 中 还 不可能 解决方法是使类成为模板 然后在成员函数中也使用模板参数 但在 OOP 的背景下 我发现如果该类实际上是一个模板 下面的示例就不会很 自然 请注意 该代码实际上不起作用 但
  • 网页抓取 Pokemon 数据

    我试图找出每个神奇宝贝 第一代 可以学习的动作数量 我发现以下网站包含此信息 https pokemondb net pokedex game red blue yellow 这里列出了 151 个 Pokemon 对于每个 Pokemon
  • PHP mysqli_real_escape_string 返回空字符串

    如果我不使用 mysql real escape string 函数 代码可以正常工作 但该函数没有返回任何内容 我读到问题可能是由于我没有 mysql 连接 但情况似乎并非如此 请帮忙
  • shell函数中的“声明”和环境变量的范围

    考虑以下测试片段 这些是文件 declare test 的内容 function do foobar unset FOOBAR declare FOOBAR default FOOBAR override echo At end of do
  • 在 pandas 中生成唯一 ID 列

    我有一个包含三列的数据框 bins x bins y and z 我想添加一个新列unique这是该独特组合的某种 索引 bins x and bins y 以下是我想附加的示例 请注意 为了清楚起见 我对数据框进行了排序 但在此上下文中顺
  • ag-Grid - 在行悬停时显示按钮,就像 Gmail 中一样

    在 ag Grid 中 我想在悬停一行时显示操作按钮 就像在 Gmail 中一样 无论滚动位置如何 操作按钮都必须出现在网格的右端 有提到一种方法https blog ag grid com build email client with
  • 是否可以在for循环语句下完成所有ajax调用后运行代码?

    我有一个for循环语句 每个循环都会执行一个ajax调用 each arr function i v var url xml php id v ajax url url type GET dataType xml success funct
  • oracle数据库中的阿拉伯字符

    亲爱的大家 我正在努力做到以下几点 我想在我的数据库中存储阿拉伯字符 但问题是它们的存储方式是 我尝试过这些功能 msg txt convert msg txt AR8MSWIN1256 AR8ISO8859P6 但我得到了这个错误 ORA
  • Ubuntu 上的 PyXML

    我刚刚完成 Ubuntu 10 10 的全新安装 我正在尝试运行一些使用 xml 和 xpath 的脚本 我从 PyXML 内部收到错误 我认为这是一个安装错误 为了安装它 我执行了以下操作 prompt gt sudo apt get i
  • 将指向同一类型结构体成员的指针分配给另一个指向同一类型结构体的指针

    即使对我来说 这个问题听起来也非常令人困惑 而且它可能看起来很明显或已经得到解答 但我已经搜索了很多 虽然我发现了有趣的东西 但我没有找到完全适合我的问题的答案 这是一些C代码将更好地显示我的疑问 typedef struct Node s
  • MySql 和 Sql Server 是否可以有 EF 上下文?

    我有两个实体框架上下文 一个用于 MySql 一个用于 sql 如果我运行该应用程序 我会收到以下错误 The default DbConfiguration instance was used by the Entity Framewor
  • 如何在 Java 中确定路由器/网关的 IP?

    如何在 Java 中确定路由器 网关的 IP 我可以很容易地获得我的IP 我可以使用网站上的服务获取我的互联网 IP 但我如何确定我的网关的 IP 如果您了解相关方法 那么在 NET 中这有点容易 但在 Java 中如何做到这一点呢 在 W
  • 如何在右边缘水平ListView添加模糊效果以显示还有更多内容

    我想要一种让用户看到他们可以在我的水平 ListView 上水平滚动的方法 提前致谢 Edit 这是我的代码 wcyankees424 Container height 50 child Stack children
  • 扩展 CouchDB Docker 镜像

    我正在尝试扩展 CouchDB docker 镜像来预填充 CouchDB 使用初始数据库 设计文档等 为了创建一个名为db 我首先尝试了这个初始Dockerfile FROM couchdb RUN curl X PUT localhos
  • 如何对 CSV 文件中更新的行运行流式查询?

    我的文件夹中有一个 csv 文件 该文件不断更新 我需要从此 csv 文件获取输入并生成一些交易 如何从持续更新的 csv 文件中获取数据 比如说每 5 分钟一次 我尝试过以下操作 val csvDF spark readStream op
  • 如何在 JavaScript 中将“/Date(1399739515000)/”转换为日期格式?

    我的 C 服务器端代码使用 listJson 它生成如下所示的时间字符串 CaptureTime Date 1399739515000 JavaScript客户端如何转换为日期格式 你可以这样做 那么 d 将是一个 javascript 变
  • 解决类别不平衡:扩大对损失和 SGD 的贡献

    已添加对此问题的更新 我是比利时根特大学的研究生 我的研究是关于深度卷积神经网络的情感识别 我正在使用Caffe实施 CNN 的框架 最近我遇到了一个关于班级不平衡的问题 我正在使用大约 9216 个训练样本 5 被标记为阳性 1 其余样本