DecisionTreeRegressor 的 Predict_proba 的等效项

2024-03-14

scikit-learn 的DecisionTreeClassifier支持通过以下方式预测每个类别的概率predict_proba()功能。这不存在于DecisionTreeRegressor:

AttributeError:“DecisionTreeRegressor”对象没有属性“predict_proba”

我的理解是,决策树分类器和回归器之间的基本机制非常相似,主要区别在于回归器的预测是作为潜在叶子的平均值来计算的。所以我希望能够提取每个值的概率。

是否有另一种方法来模拟这个,例如通过处理树结构 https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html#sphx-glr-auto-examples-tree-plot-unveil-tree-structure-py? The code https://github.com/scikit-learn/scikit-learn/blob/55bf5d9/sklearn/tree/tree.py#L804 for DecisionTreeClassifier's predict_proba不能直接转让。


该函数改编自以下代码赫尔潘德尔的回答 https://stackoverflow.com/a/53592809/1840471提供每个结果的概率:

from sklearn.tree import DecisionTreeRegressor
import pandas as pd

def decision_tree_regressor_predict_proba(X_train, y_train, X_test, **kwargs):
    """Trains DecisionTreeRegressor model and predicts probabilities of each y.

    Args:
        X_train: Training features.
        y_train: Training labels.
        X_test: New data to predict on.
        **kwargs: Other arguments passed to DecisionTreeRegressor.

    Returns:
        DataFrame with columns for record_id (row of X_test), y 
        (predicted value), and prob (of that y value).
        The sum of prob equals 1 for each record_id.
    """
    # Train model.
    m = DecisionTreeRegressor(**kwargs).fit(X_train, y_train)
    # Get y values corresponding to each node.
    node_ys = pd.DataFrame({'node_id': m.apply(X_train), 'y': y_train})
    # Calculate probability as 1 / number of y values per node.
    node_ys['prob'] = 1 / node_ys.groupby(node_ys.node_id).transform('count')
    # Aggregate per node-y, in case of multiple training records with the same y.
    node_ys_dedup = node_ys.groupby(['node_id', 'y']).prob.sum().to_frame()\
        .reset_index()
    # Extract predicted leaf node for each new observation.
    leaf = pd.DataFrame(m.decision_path(X_test).toarray()).apply(
        lambda x:x.to_numpy().nonzero()[0].max(), axis=1).to_frame(
            name='node_id')
    leaf['record_id'] = leaf.index
    # Merge with y values and drop node_id.
    return leaf.merge(node_ys_dedup, on='node_id').drop(
        'node_id', axis=1).sort_values(['record_id', 'y'])

示例(参见这个笔记本 https://colab.research.google.com/drive/1O475-dUdJNtwg8osS8FfpVH_QeMRwf9g?usp=sharing):

from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
X, y = load_boston(True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
# Works better with min_samples_leaf > 1.
res = decision_tree_regressor_predict_proba(X_train, y_train, X_test,
                                            random_state=0, min_samples_leaf=5)
res[res.record_id == 2]
#      record_id       y        prob
#   25         2    20.6    0.166667
#   26         2    22.3    0.166667
#   27         2    22.7    0.166667
#   28         2    23.8    0.333333
#   29         2    25.0    0.166667
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

DecisionTreeRegressor 的 Predict_proba 的等效项 的相关文章

  • 如何查找分布式dask中任务失败的原因?

    我正在开发一个分布式计算系统dask distributed 我通过以下方式提交给它的任务Executor map功能有时会失败 而其他看起来相同的功能却可以成功运行 该框架是否提供了诊断问题的方法 update我所说的失败是指增加 Bok
  • Python3+Kivy+Plyer 推送通知图标问题

    我在使用 Android 的简单通知测试应用程序时遇到了一个奇怪的错误 错误 python AttributeError type object notification org notificator R drawable has no
  • 在 python 2 和 3 的spyder之间切换

    根据我在文档中了解到的内容 它指出您只需使用命令提示符创建一个新变量即可轻松在 2 个 python 环境之间切换 如果我已经安装了 python 2 7 则 conda create n python34 python 3 4 anaco
  • OpenCV 错误:使用 COLOR_BGR2GRAY 函数时断言失败

    我在使用 opencv 时遇到了一个奇怪的问题 我在 jupyter 笔记本中工作时没有任何问题 但在尝试运行此 Sublime 时却出现问题 错误是 OpenCV错误 cvtColor中断言失败 深度 CV 8U 深度 CV 16U 深度
  • 无法将 datetime.datetime 与 datetime.date 进行比较

    我有以下代码并收到上述错误 由于我是 python 新手 我无法理解这里的语法以及如何修复错误 if not start or date lt start start date 有一个datetime date 从日期时间转换为日期的方法
  • DataFrame 中的字符串,但 dtype 是对象

    为什么 Pandas 告诉我我有对象 尽管所选列中的每个项目都是一个字符串 即使在显式转换之后也是如此 这是我的数据框
  • Pytest:如何使用从夹具返回的列表来参数化测试?

    我想使用由固定装置动态创建的列表来参数化测试 如下所示 pytest fixture def my list returning fixture depends on other fixtures return a dynamically
  • 如何在“python setup.py test”中运行 py.test 和 linter

    我有一个项目setup py文件 我用pytest作为测试框架 我还在我的代码上运行各种 linter pep8 pylint pydocstyle pyflakes ETC 我用tox在多个 Python 版本中运行它们 并使用以下命令构
  • multiprocessing.freeze_support()

    为什么多处理模块需要调用特定的function http docs python org dev library multiprocessing html multiprocessing freeze support在被 冻结 以生成 Wi
  • pandas 两个数据框交叉连接[重复]

    这个问题在这里已经有答案了 我找不到有关交叉联接的任何内容 包括合并 联接或其他一些内容 我需要使用 my function 作为 myfunc 处理两个数据帧 相当于 for itemA in df1 iterrows for itemB
  • 在python中调用subprocess.Popen时“系统找不到指定的文件”

    我正在尝试使用svnmerge py合并一些文件 它在底层使用 python 当我使用它时 我收到一个错误 系统找不到指定的文件 工作中的同事正在运行相同版本的svnmerge py 以及 python 2 5 2 特别是 r252 609
  • 熊猫记忆

    我有冗长的计算 我重复了很多次 因此 我想使用记忆 诸如jug http packages python org Jug and joblib http packages python org joblib memory html 与Pan
  • Matplotlib 将颜色图 tab20 更改为三种颜色

    Matplotlib 有一些新的且非常方便的颜色图 选项卡颜色图 https matplotlib org examples color colormaps reference html 我错过的是生成像 tab20b 或 tab20c 这
  • 使用 pandas 绘制带有误差线的条形图

    我正在尝试从 DataFrame 生成条形图 如下所示 Pre Post Measure1 0 4 1 9 这些值是我从其他地方计算出来的中值 我还有它们的方差和标准差 以及标准误差 我想将结果绘制为具有适当误差线的条形图 但指定多个误差值
  • DRF:以编程方式从 TextChoices 字段获取默认选择

    我们的网站是 Vue 前端 DRF 后端 在一个serializer validate 方法 我需要以编程方式确定哪个选项TextChoices类已被指定为模型字段的默认值 TextChoices 类 缩写示例 class PaymentM
  • Python列表对象属性“append”是只读的

    正如标题所说 在Python中 我试图做到这一点 以便当有人输入一个选择 在本例中为Choice13 时 它会从密码列表中删除旧密码并添加新密码 passwords mrjoebblock mrjoefblock mrjoegblock m
  • 使用Python重命名目录中的多个文件

    我正在尝试使用以下 Python 脚本重命名目录中的多个文件 import os path Users myName Desktop directory files os listdir path i 1 for file in files
  • bool() 和operator.truth() 有什么区别?

    bool https docs python org 3 library functions html bool and operator truth https docs python org 3 library operator htm
  • 如何抑制 Pandas Future 警告?

    当我运行该程序时 Pandas 每次都会给出如下所示的 未来警告 D Python lib site packages pandas core frame py 3581 FutureWarning rename with inplace
  • 从 Flask 中的 S3 返回 PDF

    我正在尝试在 Flask 应用程序的浏览器中返回 PDF 我使用 AWS S3 来存储文件 并使用 boto3 作为与 S3 交互的 SDK 到目前为止我的代码是 s3 boto3 resource s3 aws access key id

随机推荐