从多类分类算法输出前 2 个类

2024-04-14

我正在研究文本的多类分类问题,其中我有很多不同的类(15+)。 我训练了一个 Linearsvc svm 方法(方法只是示例)。 但它只输出概率最高的单个类,有没有一种算法可以同时输出两个类

我正在使用的示例代码:

from sklearn.svm import LinearSVC
import matplotlib.pyplot as plt
from sklearn.feature_extraction.text import TfidfVectorizer,CountVectorizer
count_vect = CountVectorizer(max_df=.9,min_df=.002,  
                             encoding='latin-1', 
                             ngram_range=(1, 3))
X_train_counts = count_vect.fit_transform(df_upsampled['text'])
tfidf_transformer = TfidfTransformer(sublinear_tf=True,norm='l2')
X_train_tfidf = tfidf_transformer.fit_transform(X_train_counts)
clf = LinearSVC().fit(X_train_tfidf, df_upsampled['reason'])
y_pred = model.predict(X_test)

电流输出:

    source  user   time    text         reason
0   hi      neha    0      0:neha:hi       1
1   there   ram     1      1:ram:there     1
2   ball    neha    2      2:neha:ball     3
3   item    neha    3      3:neha:item     6
4   go there ram    4      4:ram:go there  7
5   kk       ram    5      5:ram:kk        1
6   hshs    neha    6      6:neha:hshs     2
7   ggsgs   neha    7      7:neha:ggsgs    15

期望的输出:

    source  user   time    text         reason  reason2
0   hi      neha    0      0:neha:hi       1      2
1   there   ram     1      1:ram:there     1      6
2   ball    neha    2      2:neha:ball     3      7
3   item    neha    3      3:neha:item     6      4
4   go there ram    4      4:ram:go there  7      9
5   kk       ram    5      5:ram:kk        1      2
6   hshs    neha    6      6:neha:hshs     2      3
7   ggsgs   neha    7      7:neha:ggsgs    15     1

如果我只在一列中获得输出,那是可以的,因为我可以从中拆分并生成两列。


LinearSVC不提供predict_proba但它提供了decision_function它给出了距超平面的有符号距离。

来自文档:

决策函数(自我,X):

预测样本的置信度分数。

样本的置信度分数是该样本到超平面的有符号距离。

根据@warped 评论,

我们可以用decision_function输出,找到顶部n从模型预测类别。

import pandas as pd 
from sklearn.datasets import make_classification
from sklearn.svm import LinearSVC
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline

X, y = make_classification(n_samples=1000, 
                           n_clusters_per_class=1,
                           n_informative=10,
                           n_classes=5, random_state=42)

X_train, X_test, y_train, y_test = train_test_split(X, y, 
                                                    test_size=0.2,
                                                    random_state=42)
clf = make_pipeline(StandardScaler(),
                    LinearSVC(random_state=0, tol=1e-5))
clf.fit(X, y)
top_n_classes = 2
predictions = clf.decision_function(
                    X_test).argsort()[:,-top_n_classes:][:,::-1]
pred_df = pd.DataFrame(predictions, 
                       columns= [f'{i+1}_pred' for i in range(top_n_classes)])

df = pd.DataFrame({'true_class': y_test})
df = df.assign(**pred_df)

df
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

从多类分类算法输出前 2 个类 的相关文章

随机推荐

  • 查找 javascript 数组中的空索引列表

    Javascript 有没有办法找到数组为空或不包含 x 的索引 x x 会返回类似 1 2 4 5 6 我尝试过这样的事情 empty roster findIndex obj gt Object keys obj length 0 但是
  • 如何在 _layout.cshtml 中使用 data-ng-view

    我正在尝试使用 Angular 创建 MVC 应用程序 我的应用程序有通用的页眉和页脚 所以我将其添加到 layout cshtml中 应用程序中有一些静态页面 因此我想使用 Angular 路由来加载它 这是我的 layout cshtm
  • JQuery UI Datepicker:如何添加下一年/上一年按钮

    我使用通过下拉菜单选择年份的功能 我用它来设置至少 18 岁的人的生日 到目前为止 它运行得很好 我已经使用这些参数进行了设置 datepicker datepicker changeMonth true changeYear true d
  • git diff HEAD 与 git diff --staged 之间有什么区别?

    有什么区别git diff HEAD and git diff staged 我尝试了两者 但都给出了相同的输出 假设这个输出为git status git status On branch master Changes to be com
  • 如何在android中使用xml布局进行绘制

    我正在尝试完成开发人员页面上给出的 android 示例 它提供了两种在画布上绘图的方法 第一种方法是使用名为 CustomDrawableView 的类 如下所示 public class CustomDrawableView exten
  • 如何清除 WPF Frame 控件托管的整个导航历史记录

    在 WPF 应用程序中 Frame 控件用于托管 导航页面 我想清除导航历史记录 有 NavigationService RemoveBackEntry 方法可用于清除历史记录的向后部分 但是前向导航历史又如何呢 这部分怎么清除呢 最佳实践
  • Android WebView 显示纯文本而不是 html

    首先 我想说这只是 Android 2 及更早版本上的问题 4 似乎不受影响 我没有测试 3 我有一个WebView从字符串加载 html HTML 看起来像这样 h1 Hello World h1 您可以看到 css 文件如下所示 bod
  • H2DB WITH 子句

    我正在使用以下 sql 为方法编写单元测试 WITH temptab i id i name i effective i expires i lefttag i righttag hier id hier dim id parent ite
  • iPad 弹出文本字段 - resignFirstResponder 不会关闭键盘

    我有两个文本字段电子邮件和密码 当字段显示在常规视图上时 以下代码工作正常 但当它们显示在弹出窗口上时 resignFirstResponder 不起作用 becomeFirstResponder 起作用 为这两个字段调用了 textFie
  • 将 C dll 代码编组为 C#

    我在 dll 中有以下 C 代码签名 extern declspec dllexport unsigned char funct name int w int h char enc int len unsigned char text in
  • 当使用 apply() 和 call() 方法很容易继承时,为什么人们在 JavaScript 中使用原型?

    形状由矩形继承 这种继承可以通过多种方法来完成 这里我使用了apply 和call 当子类的draw方法被调用时 从该方法中再次调用基类的draw方法 我通过两种方式完成了这件事 一种是制作基类的原型绘制方法 另一种是使用 apply 和
  • Google 应用引擎 - 如何禁用缓存

    所以一些背景 我有一个在谷歌应用程序引擎上运行的nodeJS api 默认情况下 应用程序引擎会将我的所有获取请求缓存 10 分钟 我将 cloudflare 用于我的 API 因为这允许我在需要时从缓存中删除特定项目 您可以想象这会引起一
  • Swift iOS - 标签集合视图

    我正在编写我的第一个 iOS 应用程序 我只想回答最知名的解决方案是什么 这是简单的标签收集 我已经在互联网上查看过 但一无所获 我认为最好的方法可能是制作我自己的按钮结构 这是我想要实现的目标 有时你需要自己做 import UIKit
  • 在 Visual Studio Code 中自动导入以进行 React-Native 开发 [关闭]

    Closed 这个问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 有扩展名吗VS Code这使得自动导入 for 反应本机组件 例如 当我打字时
  • 异步映射中的同步部分

    我有一个大的 IO 函数 它将持续从文件夹加载数据 对数据执行纯计算 然后写回 我正在多个文件夹上并行运行此函数 mapConcurrently iofun folderList from http hackage haskell org
  • 使用 AutoMapper 进行集合的多态映射

    TL DR 我在多态映射方面遇到了麻烦 我已经制作了一个 github 存储库 其中包含一个测试套件来说明我的问题 请在这里找到它 回购链接 https github com 780Farva AutoMapperInquiry 我正在努力
  • Neo4j 客户端使用“DateTime?”展开

    我目前正在尝试展开具有 日期时间 的 TravelEdges 列表 但我不断收到以下错误 CypherTypeException 类型不匹配 需要一个地图 但是字符串 2018 05 21T08 38 00 我目前正在使用最新版本的 neo
  • Azure SQL 数据库连接问题 - 连接太多?

    我有一个最近推出的白标网站 同一网站的多个版本 目前还没有大量流量 主要是机器人 但每天可能有 800 个用户 它托管在 Azure 上 具有 Azure 数据库以及位于非 Azure 服务器上的管理面板 两个站点都连接到同一 Azure
  • 使用git打开文件的命令

    我将 Sublime Text 作为 git 中的默认编辑器 并且它有效 git config edit在 Sublime Text 中打开配置文件 很棒 我的问题 打开命令是什么index html or style css从项目目录内部
  • 从多类分类算法输出前 2 个类

    我正在研究文本的多类分类问题 其中我有很多不同的类 15 我训练了一个 Linearsvc svm 方法 方法只是示例 但它只输出概率最高的单个类 有没有一种算法可以同时输出两个类 我正在使用的示例代码 from sklearn svm i