获取sklearn中节点的决策路径

2023-11-21

我想要 scikit-learn 决策树 (DecisionTreeClassifier) 中从根节点到给定节点(我提供)的决策路径(即规则集)。clf.decision_path指定样本经过的节点,这可能有助于获取样本遵循的规则集,但是如何获取直到树中特定节点的规则集?


对于节点的决策规则,使用iris dataset:

from sklearn.datasets import load_iris
from sklearn import tree
import graphviz 

iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf, out_file=None, 
                                feature_names=iris.feature_names,  
                                class_names=iris.target_names,  
                                filled=True, rounded=True,  
                                special_characters=True)  
graph = graphviz.Source(dot_data)  
#this will create an iris.pdf file with the rule path
graph.render("iris")

enter image description here


对于基于样本的路径,请使用:

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

estimator = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
estimator.fit(X_train, y_train)

# The decision estimator has an attribute called tree_  which stores the entire
# tree structure and allows access to low level attributes. The binary tree
# tree_ is represented as a number of parallel arrays. The i-th element of each
# array holds information about the node `i`. Node 0 is the tree's root. NOTE:
# Some of the arrays only apply to either leaves or split nodes, resp. In this
# case the values of nodes of the other type are arbitrary!
#
# Among those arrays, we have:
#   - left_child, id of the left child of the node
#   - right_child, id of the right child of the node
#   - feature, feature used for splitting the node
#   - threshold, threshold value at the node

n_nodes = estimator.tree_.node_count
children_left = estimator.tree_.children_left
children_right = estimator.tree_.children_right
feature = estimator.tree_.feature
threshold = estimator.tree_.threshold

# The tree structure can be traversed to compute various properties such
# as the depth of each node and whether or not it is a leaf.
node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
is_leaves = np.zeros(shape=n_nodes, dtype=bool)
stack = [(0, -1)]  # seed is the root node id and its parent depth
while len(stack) > 0:
    node_id, parent_depth = stack.pop()
    node_depth[node_id] = parent_depth + 1

    # If we have a test node
    if (children_left[node_id] != children_right[node_id]):
        stack.append((children_left[node_id], parent_depth + 1))
        stack.append((children_right[node_id], parent_depth + 1))
    else:
        is_leaves[node_id] = True

print("The binary tree structure has %s nodes and has "
      "the following tree structure:"
      % n_nodes)
for i in range(n_nodes):
    if is_leaves[i]:
        print("%snode=%s leaf node." % (node_depth[i] * "\t", i))
    else:
        print("%snode=%s test node: go to node %s if X[:, %s] <= %s else to "
              "node %s."
              % (node_depth[i] * "\t",
                 i,
                 children_left[i],
                 feature[i],
                 threshold[i],
                 children_right[i],
                 ))
print()

# First let's retrieve the decision path of each sample. The decision_path
# method allows to retrieve the node indicator functions. A non zero element of
# indicator matrix at the position (i, j) indicates that the sample i goes
# through the node j.

node_indicator = estimator.decision_path(X_test)

# Similarly, we can also have the leaves ids reached by each sample.

leave_id = estimator.apply(X_test)

# Now, it's possible to get the tests that were used to predict a sample or
# a group of samples. First, let's make it for the sample.

# HERE IS WHAT YOU WANT
sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:

    if leave_id[sample_id] == node_id:  # <-- changed != to ==
        #continue # <-- comment out
        print("leaf node {} reached, no decision here".format(leave_id[sample_id])) # <--

    else: # < -- added else to iterate through decision nodes
        if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
            threshold_sign = "<="
        else:
            threshold_sign = ">"

        print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
              % (node_id,
                 sample_id,
                 feature[node_id],
                 X_test[sample_id, feature[node_id]], # <-- changed i to sample_id
                 threshold_sign,
                 threshold[node_id]))

这将在最后打印以下内容:

Rules used to predict sample 0: decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011920929) decision id node 2 : (X[0, 2] (= 5.1) > 4.949999809265137) leaf node 4 reached, no decision here


本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

获取sklearn中节点的决策路径 的相关文章

随机推荐

  • Windows 事件查看器锁定了我的 EXE 文件

    我对某件事很好奇 我正在开发一个 Windows 服务并将所有诊断事件记录到 Windows 事件日志中 因此 当服务运行时 我打开事件查看器 从管理工具 来查看服务运行的结果 除了当我需要卸载程序时 再次出于测试目的 这非常有效 出于某种
  • 是否使用辅助角色或 Web 角色:Windows Azure

    我正在编写一个小型计算程序 对 blob 文件进行大量读取操作 我应该选择工作者角色还是网络角色 Web 角色和辅助角色之间的唯一区别是 在 Web 角色中 IIS 实际上是托管 Web 核心 启动并指向您的应用程序数据目录 您仍然可以将代
  • 如果上次修改日期已经过了某个时间,我如何告诉 Camel 仅复制文件?

    我想知道这是否可以用 Apache Camel 来实现 我想做的是让 Camel 查看文件目录 并只复制 上次修改 日期比某个日期更新的文件 例如 仅复制 2014 年 2 月 7 日之后修改的文件 基本上 我想在每次 Camel 运行时更
  • 查找 .NET 程序集中的字节偏移量

    我正在尝试调试客户向我们报告的错误 堆栈跟踪只有字节偏移量 没有行号 e g NullReferenceException 未将对象引用设置为对象的实例 Foo Bar FooFoo p 32Foo BarBar 191Foo BarBar
  • 测试立即失败,并出现未知错误:通过 systemd 运行 Selenium 网格时,DevToolsActivePort 文件不存在

    我一直在尝试改变从 shell 脚本启动 Selenium 网格服务的方式 rclocal to a systemd服务 但不起作用 脚本是这样的 bin bash java jar opt selenium server standalo
  • 关于C++默认值的一些问题

    我对函数参数列表中的默认值有一些疑问 默认值是签名的一部分吗 默认参数的参数类型怎么样 默认值存储在哪里 在堆栈或全局堆中还是在常量数据段中 否 默认argument不是签名的一部分 也不是函数类型的一部分 参数类型是签名的一部分 但默认参
  • 传递所有适用类型的函数

    我遵循了发现的建议here定义一个名为 square 的函数 然后尝试将其传递给一个名为两次的函数 函数定义如下 def square T n T implicit numeric Numeric T T numeric times n n
  • 在 Linux 内核模块中读/写文件

    我知道所有关于为什么不应该从内核读取 写入文件的讨论 而是如何使用 proc or netlink要做到这一点 无论如何我想读 写 我也读过让我发疯 你永远不应该在内核中做的事情 然而问题是2 6 30不导出sys read 相反 它被包裹
  • 我是否需要在 C++ 线程中使用整数锁定

    如果我在多个线程中访问单个整数类型 例如 long int bool 等 我是否需要使用同步机制 例如互斥体 来锁定它们 我的理解是 作为原子类型 我不需要锁定对单个线程的访问 但我看到很多代码确实使用了锁定 对此类代码进行分析表明 使用锁
  • DB2 中的 SQL Server 事务相当于什么?

    DB2 中的以下 SQL Server 语句等效于什么 开始交易 提交交易 回滚事务 答案实际上比这里指出的要复杂一些 确实 事务是 ANSI 标准化的 而 DB2may支持他们 DB2 for z OS 与其他变体 LUW Linux U
  • 重置 IRB 控制台

    如何告别所有定义的常量 对象等in an irb会话回到干净的状态 经过 in 我的意思是不操纵子会话 Type exec 0 在您的 IRB 控制台会话中
  • UIView 纵横比混淆了 systemLayoutSizeFittingSize

    好吧 另一个 UITableViewCell 动态高度问题 但有一点点扭曲 不幸的是我只能在发布时跳转到iOS 8 否则问题就解决了 需要 iOS gt 7 1 我试图实现一个单元格 单元格顶部有两个图像 下面有一个标题标签 下面有一个描述
  • 如何在Sql commandText中传递int参数

    如何像SQL命令参数一样传递整数值 我正在尝试这样 cmd CommandText insert questions cmd Parameters AddWithValue store result store result cmd Par
  • 使用 DirectoryIterator 对文件进行排序

    我正在创建一个目录 列出 lighttpd 的 PHP5 脚本 在给定的目录中 我希望能够列出直接子目录和文件 带有信息 快速搜索后 目录迭代器似乎是我的朋友 foreach new DirectoryIterator as file ec
  • 移动网站设计

    我刚刚使用样式表 即 media print 等 向网站添加了打印功能 并且想知道是否可以使用类似的方法来添加对移动设备的支持 如果没有 我如何检测移动设备 我的页面是 C aspx 我想缩小页面以便于在移动设备上使用 对我有什么建议吗 编
  • 如何在静态类中使用IHttpContextAccessor设置cookie

    我正在尝试创建一个通用的addReplaceCookie静态类中的方法 该方法看起来像这样 public static void addReplaceCookie string cookieName string cookieValue i
  • 如何在Python中解析带有'+'的标签

    当我尝试编译此内容时 出现 无重复 错误 search re compile r a zA Z0 9 s a zA Z0 9 test re I 问题是 号 我该怎么处理 re compile r a zA Z0 9 s a zA Z0 9
  • AVPlayer 不会在 iOS9 中播放来自 url 的视频

    我试图在 UIView 中嵌入 AVPlayer 并从 url 播放 mp4 视频文件 问题是我只收到黑色空白视图 参见屏幕截图 在以前的 iOS 版本中 它对我有用 但自从升级到 iOS9 后 我遇到了这个问题 我的 h 文件如下所示 i
  • 在 ASP.NET 中生成 PDF 文档[重复]

    这个问题在这里已经有答案了 可能的重复 直接将 aspx 转换为 pdf 有没有办法直接从页面输出从asp net生成PDF文档 我的要求是 当用户访问我网站上的页面时 应该有一个条款可以获取 PDF 格式的相同页面 报告 使用iTextS
  • 获取sklearn中节点的决策路径

    我想要 scikit learn 决策树 DecisionTreeClassifier 中从根节点到给定节点 我提供 的决策路径 即规则集 clf decision path指定样本经过的节点 这可能有助于获取样本遵循的规则集 但是如何获取