使用 xgboost 分类器进行多类分类?

2024-01-29

我正在尝试使用 xgboost 进行多类分类,并使用此代码构建了它,

clf = xgb.XGBClassifier(max_depth=7, n_estimators=1000)

clf.fit(byte_train, y_train)
train1 = clf.predict_proba(train_data)
test1 = clf.predict_proba(test_data)

这给了我一些好的结果。我的案例的对数损失低于 0.7。但浏览了几页后,我发现我们必须使用 XGBClassifier 中的另一个目标来解决多类问题。以下是这些页面的推荐内容。

clf = xgb.XGBClassifier(max_depth=5, objective='multi:softprob', n_estimators=1000, 
                        num_classes=9)

clf.fit(byte_train, y_train)  
train1 = clf.predict_proba(train_data)
test1 = clf.predict_proba(test_data)

该代码也可以工作,但与我的第一个代码相比,它需要花费很多时间才能完成。

为什么我的第一个代码也适用于多类案例?我已经检查过它的默认目标是二元:用于二元分类的逻辑,但它对于多类确实效果很好?如果两者都正确,我应该使用哪一个?


事实上,即使默认的 obj 参数XGBClassifier is binary:logistic,它会内部判断标签y的类别数。当类号大于2时,会修改obj参数为multi:softmax.

https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/sklearn.py https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/sklearn.py

class XGBClassifier(XGBModel, XGBClassifierBase):
    # pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
    def __init__(self, objective="binary:logistic", **kwargs):
        super().__init__(objective=objective, **kwargs)

    def fit(self, X, y, sample_weight=None, base_margin=None,
            eval_set=None, eval_metric=None,
            early_stopping_rounds=None, verbose=True, xgb_model=None,
            sample_weight_eval_set=None, callbacks=None):
        # pylint: disable = attribute-defined-outside-init,arguments-differ

        evals_result = {}
        self.classes_ = np.unique(y)
        self.n_classes_ = len(self.classes_)

        xgb_options = self.get_xgb_params()

        if callable(self.objective):
            obj = _objective_decorator(self.objective)
            # Use default value. Is it really not used ?
            xgb_options["objective"] = "binary:logistic"
        else:
            obj = None

        if self.n_classes_ > 2:
            # Switch to using a multiclass objective in the underlying
            # XGB instance
            xgb_options['objective'] = 'multi:softprob'
            xgb_options['num_class'] = self.n_classes_
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用 xgboost 分类器进行多类分类? 的相关文章

  • Django 的内联管理:一个“预填充”字段

    我正在开发我的第一个 Django 项目 我希望用户能够在管理中创建自定义表单 并向其中添加字段当他或她需要它们时 为此 我在我的项目中添加了一个可重用的应用程序 可在 github 上找到 https github com stephen
  • 元组有什么用?

    我现在正在学习 Python 课程 我们刚刚介绍了元组作为数据类型之一 我阅读了它的维基百科页面 但是 我无法弄清楚这种数据类型在实践中会有什么用处 我可以提供一些需要一组不可变数字的示例吗 也许是在 Python 中 这与列表有何不同 每
  • 如何用python脚本控制TP LINK路由器

    我想知道是否有一个工具可以让我连接到路由器并关闭它 然后从 python 脚本重新启动它 我知道如果我写 import os os system ssh l root 192 168 2 1 我可以通过 python 连接到我的路由器 但是
  • 将html数据解析成python列表进行操作

    我正在尝试读取 html 网站并提取其数据 例如 我想查看公司过去 5 年的 EPS 每股收益 基本上 我可以读入它 并且可以使用 BeautifulSoup 或 html2text 创建一个巨大的文本块 然后我想搜索该文件 我一直在使用
  • Pandas 日期时间格式

    是否可以用零后缀表示 pd to datetime 似乎零被删除了 print pd to datetime 2000 07 26 14 21 00 00000 format Y m d H M S f 结果是 2000 07 26 14
  • Python zmq SUB 套接字未接收 MQL5 Zmq PUB 套接字

    我正在尝试在 MQL5 中设置一个 PUB 套接字 并在 Python 中设置一个 SUB 套接字来接收消息 我在 MQL5 中有这个 include
  • 如何使用 Pandas、Numpy 加速 Python 中的嵌套 for 循环逻辑?

    我想检查一下表的字段是否TestProject包含了Client端传入的参数 嵌套for循环很丑陋 有什么高效简单的方法来实现吗 非常感谢您的任何建议 def test parameter a list parameter b list g
  • YOLOv8获取预测边界框

    我想将 OpenCV 与 YOLOv8 集成ultralytics 所以我想从模型预测中获取边界框坐标 我该怎么做呢 from ultralytics import YOLO import cv2 model YOLO yolov8n pt
  • datetime.datetime.now() 返回旧值

    我正在通过匹配日期查找 python 中的数据存储条目 我想要的是每天选择 今天 的条目 但由于某种原因 当我将代码上传到 gae 服务器时 它只能工作一天 第二天它仍然返回相同的值 例如当我上传代码并在 07 01 2014 执行它时 它
  • Python 2:SMTPServerDisconnected:连接意外关闭

    我在用 Python 发送电子邮件时遇到一个小问题 me my email address you recipient s email address me email protected cdn cgi l email protectio
  • Python beautifulsoup 仅限 1 级文本

    我看过其他 beautifulsoup 得到相同级别类型的问题 看来我的有点不同 这是网站 我正试图拿到右边那张桌子 请注意表的第一行如何展开为该数据的详细细分 我不想要那个数据 我只想要最顶层的数据 您还可以看到其他行也可以展开 但在本例
  • 如何使用 Mysql Python 连接器检索二进制数据?

    如果我在 MySQL 中创建一个包含二进制数据的简单表 CREATE TABLE foo bar binary 4 INSERT INTO foo bar VALUES UNHEX de12 然后尝试使用 MySQL Connector P
  • pyspark 将 twitter json 流式传输到 DF

    我正在从事集成工作spark streaming with twitter using pythonAPI 我看到的大多数示例或代码片段和博客是他们从Twitter JSON文件进行最终处理 但根据我的用例 我需要所有字段twitter J
  • 加快网络抓取速度

    我正在使用一个非常简单的网络抓取工具抓取 23770 个网页scrapy 我对 scrapy 甚至 python 都很陌生 但设法编写了一个可以完成这项工作的蜘蛛 然而 它确实很慢 爬行 23770 个页面大约需要 28 小时 我看过scr
  • 如何在 Windows 命令行中使用参数运行 Python 脚本

    这是我的蟒蛇hello py script def hello a b print hello and that s your sum sum a b print sum import sys if name main hello sys
  • 如何在 pygtk 中创建新信号

    我创建了一个 python 对象 但我想在它上面发送信号 我让它继承自 gobject GObject 但似乎没有任何方法可以在我的对象上创建新信号 您还可以在类定义中定义信号 class MyGObjectClass gobject GO
  • 使用for循环时如何获取前一个元素? [复制]

    这个问题在这里已经有答案了 可能的重复 Python 循环内的上一个和下一个值 https stackoverflow com questions 1011938 python previous and next values inside
  • Django-tables2 列总计

    我正在尝试使用此总结列中的所有值文档 https github com bradleyayers django tables2 blob master docs pages column headers and footers rst 但页
  • 如何计算Python中字典中最常见的前10个值

    我对 python 和一般编程都很陌生 所以请友善 我正在尝试分析包含音乐信息的 csv 文件并返回最常听的前 n 个乐队 从下面的代码中 每听一首歌曲都是一个列表中的字典条目 格式如下 album Exile on Main Street
  • Pandas 每周计算重复值

    我有一个Dataframe包含按周分组的日期和 ID df date id 2022 02 07 1 3 5 4 2022 02 14 2 1 3 2022 02 21 9 10 1 2022 05 16 我想计算每周有多少 id 与上周重

随机推荐