Keras 如何处理多标签分类?

2023-12-08

我不确定如何解释 Keras 在以下情况下的默认行为:

我的 Y(基本事实)是使用 scikit-learn 设置的MultilabelBinarizer().

因此,举一个随机的例子,我的一排y列是 one-hot 编码的,如下所示:[0,0,0,1,0,1,0,0,0,0,1].

所以我有 11 类可以预测,并且不止一类可以是真实的;因此问题的多标签性质。该特定样本有三个标签。

我像处理非多标签问题(一切照常)一样训练模型,并且没有收到任何错误。

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.optimizers import SGD

model = Sequential()
model.add(Dense(5000, activation='relu', input_dim=X_train.shape[1]))
model.add(Dropout(0.1))
model.add(Dense(600, activation='relu'))
model.add(Dropout(0.1))
model.add(Dense(y_train.shape[1], activation='softmax'))

sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy',
              optimizer=sgd,
              metrics=['accuracy',])

model.fit(X_train, y_train,epochs=5,batch_size=2000)

score = model.evaluate(X_test, y_test, batch_size=2000)
score

当 Keras 遇到我的时会做什么y_train并看到它是“多”单热编码的,这意味着每一行中存在多个“一”y_train?基本上,Keras 会自动执行多标签分类吗?对评分指标的解释有什么不同吗?


In short

不要使用softmax.

Use sigmoid用于激活输出层。

Use binary_crossentropy为损失函数。

Use predict进行评估。

Why

In softmax当一个标签的分数增加时,所有其他标签的分数都会降低(这是一种概率分布)。当你有多个标签时你不希望这样。

完整代码

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation
from tensorflow.keras.optimizers import SGD

model = Sequential()
model.add(Dense(5000, activation='relu', input_dim=X_train.shape[1]))
model.add(Dropout(0.1))
model.add(Dense(600, activation='relu'))
model.add(Dropout(0.1))
model.add(Dense(y_train.shape[1], activation='sigmoid'))

sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='binary_crossentropy',
              optimizer=sgd)

model.fit(X_train, y_train, epochs=5, batch_size=2000)

preds = model.predict(X_test)
preds[preds>=0.5] = 1
preds[preds<0.5] = 0
# score = compare preds and y_test
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Keras 如何处理多标签分类? 的相关文章

随机推荐

  • IntelliJ 问题 -> 无法创建名为“Main”的类

    标题说明了我的问题 我收到此错误消息 无法创建类无法解析模板 Class 错误信息 选定的类文件名 Main java 映射到非 java 文件类型 通过 TextMate 捆绑包支持的文件 有人对我如何解决这个问题有任何想法吗 请检查文件
  • 拆分字符串列值

    acctcode primekey groupby lt columns WDS 1 NULL lt values varchar FDS 2 NULL IRN 3 NULL SUM 4 1 2 3 STL 5 NULL WTR 6 NUL
  • 扩展 Asp.NET MVC3 控制器类

    我是一位经验丰富的 NET 程序员 也是一位使用 PHP 的 MVC 程序员 现在我是 MVC3 的新手 并尝试在其上构建我的第一个作品 因此我正在处理一些问题 对于初学者来说 如何扩展控制器类 有人可以指出我应该实施的指南 方法列表吗 T
  • 无法释放 C 中的 const 指针

    我怎样才能释放一个const char 我使用分配新内存malloc 当我尝试释放它时 我总是收到错误 不兼容的指针类型 导致此问题的代码类似于 char name Arnold const char str const char mall
  • Android 获取当前时间戳?

    我想像这样获取当前时间戳 1320917972 int time int System currentTimeMillis Timestamp tsTemp new Timestamp time String ts tsTemp toStr
  • Jenkins:根据相同 Jenkins 作业中的每个构建步骤结果发送电子邮件

    我只是想知道如何发送电子邮件电子邮件分机插件基于相同 Jenkins 作业的每个构建步骤结果 这是我的场景 我的 Jenkins 工作有 3 个构建步骤 构建步骤1 Pull latest code from github and Buil
  • 如何从 C++ 调用 fortran 例程?

    我希望从我的 C 代码中调用 fortran 例程 cbesj f 如何实现此目的 以下是我已完成的步骤 从 netlib amos 网页下载 cbesj f 以及依赖项 http www netlib org cgi bin netlib
  • 自动完成建议列表的 z-index 错误,我该如何更改?

    似乎我的自动完成列表的 z index 比我网站的某些元素低 所以它暴露不足 我应该编辑什么类 使用editCSS我播种这些类 并添加 我网站的z索引 但很少有不影响的是1 ui corner all ui menu item ingred
  • 如何打印第三列到最后一列?

    我正在尝试从 DbgView 日志文件中删除前两列 我对其中不感兴趣 我似乎找不到从第 3 列开始打印直到行尾的示例 请注意 每行都有可变数量的列 或更简单的解决方案 cut f 3 INPUTFILE只需添加正确的分隔符 d 即可获得相同
  • JTable 中的列的多个单元格渲染器?

    假设我有以下 JTable 按下按钮后就会显示 Name True Hello World False Foo Bar True Foo False Bar 我想渲染那些单元格最初对于 JCheckBox 来说是正确的 并且所有单元格都是最
  • MonoTouch.Dialog 崩溃

    我有一个小型测试应用程序 它仅在 3 个页面之间循环 这是应用程序委托 public override bool FinishedLaunching UIApplication app NSDictionary options sessio
  • 如何从嵌套函数内部访问 Stimulus JS 控制器方法?

    我有一个 Stimulus 控制器 其中有一个 setSegments 函数 然后在 connect 方法中使用以下代码 connect const options overview container document getElemen
  • 十六进制到二进制转换

    我已通过十六进制转换器将 jpeg 文件转换为十六进制代码 现在如何将该十六进制转换为二进制并另存为Jpeg磁盘上的文件 Like var 声明为十六进制代码 然后将该 var 十六进制代码转换为二进制并保存在磁盘上 Edit Var my
  • 如何使用X509使用JDBC连接MySQL?

    我已经设置了 MySQL 社区服务器 5 1 数据库服务器 我已经设置了 SSL 创建了证书等 我创建了一个具有 REQUIRES X509 属性的用户 我可以使用命令行客户端 mysql 使用此用户进行连接 并且 status 命令显示
  • 请解释一下此电子邮件验证正则表达式:[关闭]

    很难说出这里问的是什么 这个问题模棱两可 含糊不清 不完整 过于宽泛或言辞激烈 无法以目前的形式合理回答 如需帮助澄清此问题以便重新打开 访问帮助中心 我有这个脚本使用正则表达式来检查表单字段是否包含有效的电子邮件地址 请从声明中解释一下
  • Firebase 安全规则 - Auth 生成的 UID 是否应该保密? [复制]

    这个问题在这里已经有答案了 我一直在阅读 Firebase 实时数据库安全规则指南 https firebase google com docs database security 我有点困惑是否应该将 Firebase Auth 生成的
  • 如何将 Tensorflow BatchNormalization 与 GradientTape 结合使用?

    假设我们有一个使用 BatchNormalization 的简单 Keras 模型 model tf keras Sequential tf keras layers InputLayer input shape 1 tf keras la
  • 基于 gnu readline 的节点 shell

    是否有一个在内部使用 gnu readline 的 Node 外壳 As you know node shell sucks in 2 ways among others It doesn t have search for history
  • 是否可以将鼠标光标放在元素后面或者鼠标光标是否有 z 索引?

    当鼠标悬停在某个元素上时 我想用自定义图像替换鼠标光标 我通过首先关闭鼠标光标来做到这一点 cursor none 当它悬停在元素上时 然后我读出悬停元素上的光标位置 并将图形的 css 位置设置为光标位置并稍微偏移 以便鼠标光标不在图形上
  • Keras 如何处理多标签分类?

    我不确定如何解释 Keras 在以下情况下的默认行为 我的 Y 基本事实 是使用 scikit learn 设置的MultilabelBinarizer 因此 举一个随机的例子 我的一排y列是 one hot 编码的 如下所示 0 0 0