CatBoost 精度不平衡类

2023-12-03

我使用 CatBoostClassifier,我的类高度不平衡。我应用了一个scale_pos_weight参数来解决这个问题。在使用评估数据集(测试)进行训练时,CatBoost 在测试中显示出很高的精度。然而,当我使用预测方法对测试进行预测时,我只得到低精度分数(使用 sklearn.metrics 计算)。

我认为这可能与我应用的班级权重有关。但是,我不太明白精度分数是如何受此影响的。

params = frozendict({
    'task_type': 'CPU',
    'loss_function': 'Logloss',
    'eval_metric': 'F1', 
    'custom_metric': ['F1', 'Precision', 'Recall'],
    'iterations': 100,
    'random_seed': 20190128,
    'scale_pos_weight': 56.88657244809081,
    'learning_rate': 0.5412829495147387, 
    'depth': 7, 
    'l2_leaf_reg': 9.526905230698302
})

from catboost import CatBoostClassifier
model = cb.CatBoostClassifier(**params)
model.fit(
    X_train, y_train,
    cat_features=np.where(X_train.dtypes == np.object)[0],
    eval_set=(X_test, y_test),
    verbose=False,
    plot=True
)

model.get_best_score()
{'learn': {'Recall': 0.9243007537531925,
  'Logloss': 0.15892360013680026,
  'F1': 0.9416723809244181,
  'Precision': 0.9640191600545249},
 'validation_0': {'Recall': 0.914252301192093,
  'Logloss': 0.1714387314107052,
  'F1': 0.9357892623978286,
  'Precision': 0.9642642597943112}}

y_test_pred = model.predict(data=X_test)

from sklearn.metrics import balanced_accuracy_score, recall_score, precision_score, f1_score
print('Balanced accuracy: {:.2f}'.format(balanced_accuracy_score(y_test, y_test_pred)))
print('Precision: {:.2f}'.format(precision_score(y_test, y_test_pred)))
print('Recall: {:.2f}'.format(recall_score(y_test, y_test_pred)))
print('F1: {:.2f}'.format(f1_score(y_test, y_test_pred)))

Balanced accuracy: 0.94
Precision: 0.29
Recall: 0.91
F1: 0.44

我期望在训练时获得与 CatBoost 相同的精度,但事实并非如此。我究竟做错了什么?


Default use_weights被设定为True,这意味着为评估指标添加权重,例如Precision:use_weights=True, 为了让你自己的计算器精度和他的一样,改成Precision:use_weights=False

Also, get_best_score在迭代中给出最高分数,您需要指定在预测中使用哪个迭代。您可以设置use_best_model=True in model.fit自动选择迭代。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

CatBoost 精度不平衡类 的相关文章

  • scikit-learn RandomForestClassifier 中的子样本大小

    如何控制用于训练森林中每棵树的子样本的大小 根据 scikit learn 的文档 随机森林是一种适合许多决策的元估计器 数据集的各个子样本上的树分类器并使用 平均以提高预测准确性并控制过度拟合 子样本大小始终与原始输入样本相同 大小 但如
  • 如何反转 dropout 来补偿 dropout 的影响并保持期望值不变?

    我正在学习神经网络中的正则化deeplearning ai课程 在dropout正则化中 教授说 如果应用dropout 计算出的激活值将比不应用dropout时 测试时 更小 因此 我们需要扩展激活以使测试阶段更简单 我理解这个事实 但我
  • Scikit-learn、带有洗牌组的 GroupKFold?

    我正在使用 scikit learn 中的 StratifiedKFold 但现在我还需要观察 组 有一个很好的函数 GroupKFold 但我的数据非常依赖时间 与帮助中的相似 即周数是分组索引 但每周应该只折叠一次 假设我需要折叠 10
  • 使用 scikit 包在 Python 中绘制集群区域的边界

    这是我处理 3 个属性 x y 值 中的数据聚类的简单示例 每个样本代表其位置 x y 及其所属变量 我的代码发布在这里 x np arange 100 200 1 y np arange 100 200 1 value np random
  • kmeans 对分组数据进行聚类

    目前 我尝试在分组数据中找到簇的中心 通过使用示例数据集和问题定义 我能够创建kmeans每个组内的集群 然而 当涉及到给定组的集群的每个中心时 我不知道如何获取它们 https rdrr io cran broom man kmeans
  • Java 的支持向量机?

    我想用Java编写一个 智能监视器 它可以随时发出警报detects即将到来的性能问题 我的 Java 应用程序正在以结构化格式将数据写入日志文件
  • scikit-learn 和tensorflow 有什么区别?可以一起使用它们吗?

    对于这个问题我无法得到满意的答案 据我了解 TensorFlow是一个数值计算库 经常用于深度学习应用 而Scikit learn是一个通用机器学习框架 但它们之间的确切区别是什么 TensorFlow 的目的和功能是什么 我可以一起使用它
  • ValueError:没有为“dense_input”提供数据

    我正在使用以下简单的代码使用tensorflow加载csv并使用keras执行建模 无法弄清楚这个错误 import tensorflow as tf train dataset fp tf keras utils get file fna
  • 如何重现 Ridge(normalize=True) 的行为?

    这段代码 from sklearn pipeline import make pipeline from sklearn preprocessing import StandardScaler from sklearn linear mod
  • 从sklearn PCA获取特征值和向量

    如何获取 PCA 应用程序的特征值和特征向量 from sklearn decomposition import PCA clf PCA 0 98 whiten True converse 98 variance X train clf f
  • ValueError:不支持连续[重复]

    这个问题在这里已经有答案了 我正在使用 GridSearchCV 进行线性回归的交叉验证 不是分类器也不是逻辑回归 我还使用 StandardScaler 对 X 进行标准化 我的数据框有 17 个特征 X 和 5 个目标 y 观察 约11
  • randomForest 包在删除一个预测类时的奇怪行为

    我正在运行一个随机森林模型 它产生的结果从统计角度来看对我来说完全没有意义 因此我确信有些东西mustrandomForest 包的代码出现错误 至少在模型的本次迭代中 预测 左侧变量是具有 3 种可能结果的政党 ID 民主党 独立党 共和
  • sklearn 估计器管道的参数无效

    我正在实现 O Reilly 书中的一个示例 Python 机器学习简介 使用 Python 2 7 和 sklearn 0 16 我正在使用的代码 pipe make pipeline TfidfVectorizer LogisticRe
  • Keras model.predict 函数给出输入形状错误

    我已经在 Tensorflow 中实现了通用句子编码器 现在我正在尝试预测句子的类概率 我也将字符串转换为数组 Code if model model type universal classifier basic class probs
  • 在 Keras 模型中删除然后插入新的中间层

    给定一个预定义的 Keras 模型 我尝试首先加载预先训练的权重 然后删除一到三个模型内部 非最后几层 层 然后用另一层替换它 我似乎找不到任何有关的文档keras io https keras io 即将做这样的事情或从预定义的模型中删除
  • sklearn 中的 pca.inverse_transform

    将我的数据拟合后 X 我的数据 pca PCA n components 1 pca fit X X pca pca fit transform X 现在 X pca 具有一维 当我根据定义执行逆变换时 它不是应该返回原始数据 即 X 二维
  • 如何创建增量NER训练模型(追加到现有模型中)?

    我正在训练定制命名实体识别 NER 模型使用斯坦福自然语言处理但问题是我想要重新训练模型 Example 假设我训练过xyz模型 然后我将在一些文本上测试它 如果模型检测到错误 那么我 最终用户 将更正它并希望在更正的文本上重新训练 追加模
  • 如何使用 AdaBoost 进行特征选择?

    我想使用 AdaBoost 从大量 100k 中选择一组好的特征 AdaBoost 的工作原理是迭代功能集并根据功能的执行情况添加功能 它选择对现有特征集错误分类的样本表现良好的特征 我目前正在 Open CV 中使用CvBoost 我得到
  • keras 模型拟合:ValueError:无法找到可以处理输入的数据适配器:

    我正在构建一个简单的 CNN 模型用于多类分类 训练和测试数据位于data path根据所需的类子目录flow from directory的函数ImageDataGenerator 这是我根据数据构建和训练模型的代码 from tenso
  • 与 GridSearchCV 的并行错误,与其他方法一起工作正常

    我使用 GridSearchCV 时遇到以下问题 它在使用时给我一个并行错误n jobs gt 1 同时n jobs gt 1与 RadonmForestClassifier 等单一模型配合良好 下面是一个显示错误的简单工作示例 train

随机推荐

  • 从应用程序将照片上传到 Facebook 相册

    我用过 req perms gt publish stream status update 我收到的错误是 致命错误 未捕获的 CurlException 26 创建在 facebook php 第 589 行抛出的表单数据失败 我的上传代
  • 悬停无法与 jQuery 工具一起使用 - jQuery

    当我添加jQuery 工具到我的页面 链接上的悬停效果不起作用 没有它 它也能工作
  • 在 NSTextField 上按下 Enter 键时如何执行某些操作

    我正在使用 Swift 为 Mac 编写一个应用程序 我在 NSTextField 对象中写入一个字符串 我想将其保存在 txt 文件中 我希望用户按下 Enter 键后立即发生这种情况 我的方法 writeToFile 准备好了 我不知道
  • 解压缩来自 WebClient 的 gzip 响应

    有没有一种快速的方法来解压缩使用 WebClient DownloadString 方法下载的 gzip 响应 您对如何使用 WebClient 处理 gzip 响应有什么建议吗 最简单的方法是使用内置的自动减压与HttpWebReques
  • GNU JavaMail:没有地址提供者:rfc822

    使用 OpenJDK 1 7 0 和 GNU JavaMail 1 1 2 在实际消息发送调用期间 SMTPTransport send msg 有时候是这样的 javax mail NoSuchProviderException No p
  • 在 Apple 审核之前获取 App Store URL

    在应用程序的 beta 测试阶段 在 Apple 审核该应用程序之前 是否可以生成应用程序商店 URL 我想在我的应用程序中添加一个指向 App Store 中我的应用程序的链接 用户可以与朋友分享该链接 我希望在 Beta 测试阶段提供此
  • 更改MFC控件中背景和标题的颜色

    我想更改 MFC 应用程序中的编辑控件 静态控件和按钮控件的文本颜色和背景颜色 该控件位于一个CDialogEx对话 我尝试添加 OnCtlColor 使用 Visual Studio 中的向导 在 WM CTLCOLR 消息上 但我无法设
  • 在 WKWebView 中禁用 cookie

    是否可以在 WKWebView 中禁用 cookie 和本地存储 假设这是我的设置 我想添加一些禁用它们的内容 import UIKit import WebKit class ViewController UIViewController
  • Python pandas 不识别特殊字符

    我正在尝试使用df column name str count 在 python pandas 中 但我收到 错误 没有可重复的 对于常规字符 该方法有效 例如df column name str count a 工作正常 另外 符号也有问
  • 生成包含条件项的列表

    是否可以创建一个包含条件项的数组 my a 1 condition 2 no op 3 这样 no op 是一个函数 如果 condition是假的 然后我得到列表 1 3 but if condition是真的 我明白了 1 2 3 背景
  • 使用自定义 UIBezierPath 剪切图像

    我想知道是否有人可以为我指出这个问题的正确方向 我有一个用户创建的UIBezierPath有几个点是由用户触摸引起的 我可以使用这些在沼泽标准 UIView 上创建形状 myPath fill 功能 我理想中想做的是使用路径为 UIImag
  • MySQL 布尔全文搜索中的“显示除所有内容”

    使用 MySQL 布尔全文搜索 http dev mysql com doc refman 5 1 en fulltext boolean html 前导减号表示 这个词不能出现在任何 返回的行数 注意 运算符仅用于 排除其他行 与其他搜索
  • 使用Java/JSP打印支票

    我正在开发一个现有的 Java Web 应用程序 此特定应用程序中的 HTML CSS JS JSP Servlet 和 Java 类 该应用程序当前使用小程序来打印支票 我的老板最近来找我 告诉我在针对最新版本的 Java 测试支票打印时
  • 链接消费者 Java 8

    您好 我遇到以下问题 假设我们有对象 Account 该对象 Account 是不可变的 因此随着时间的推移 我们对其执行操作 实际上是将其转换为另一种状态 例如Account可以变成ClosedAccount或NewAccount等等 现
  • 在 Android Studio 中添加新模块时 java.lang.NoClassDefFoundError: android.support.v4.app.NavUtilsJB 错误

    添加新模块时出现奇怪的错误 https github com lomza android color picker 到我的项目 如果没有这个模块 项目运行正常 但是如果将此项目作为模块添加到我的主项目中并编译它 一切看起来都很好 但应用程序
  • 数据库速度优化:少表多行,还是多表少行?

    我有一个很大的疑问 让我们以任何公司订单的数据库为例 假设这家公司每月大约发出 2000 个订单 那么 每年大约 24K 个订单 他们不想删除任何订单 即使它已经有 5 年了 嘿 这是一个例子 数字并不意味着任何事物 就拥有良好的数据库查询
  • 将匿名函数传递给具有局部变量的命名函数时,Javascript 中的范围问题

    对这个标题感到抱歉 我不知道如何表达它 这是场景 我有一个构建元素的函数 buildSelect id cbFunc 在 buildSelect 中它执行以下操作 select attachEvent onchange cbFunc 我还有
  • 创建 PDF 的最佳 C# API [关闭]

    就目前情况而言 这个问题不太适合我们的问答形式 我们希望答案得到事实 参考资料或专业知识的支持 但这个问题可能会引发辩论 争论 民意调查或扩展讨论 如果您觉得这个问题可以改进并可能重新开放 访问帮助中心以获得指导 您能推荐任何适用于 C 的
  • awk 打印每个类别的所有最小值

    想要打印基于的所有最小值1 美元和 3 美元组合 如果有两条或多条线路可用 对于具有 1 和 3 唯一组合的最小值 则需要打印所有行 例如 1 Abc 的最小值 3 10 出现两次 即 Abc yyy 10 aaa 和 Abc ttt 10
  • CatBoost 精度不平衡类

    我使用 CatBoostClassifier 我的类高度不平衡 我应用了一个scale pos weight参数来解决这个问题 在使用评估数据集 测试 进行训练时 CatBoost 在测试中显示出很高的精度 然而 当我使用预测方法对测试进行