从 SKlearn 决策树中检索决策边界线(x,y 坐标格式)

2023-11-26

我正在尝试在外部可视化平台上创建曲面图。我正在使用 iris 数据集sklearn 决策树文档页面。我还使用相同的方法来创建决策曲面图。但我的最终目标不是 matplot lib 视觉效果,因此从这里我将数据输入到我的可视化软件中。为此,我刚刚打电话flatten() and tolist() on xx, yy and Z并编写了一个包含这些列表的 JSON 文件。

问题是当我尝试绘制它时,我的可视化程序崩溃了。事实证明数据太大了。展平后,列表的长度>86,000。这是因为步长/绘图步长非常小.02。因此,它本质上是根据模型的预测,在数据的最小值和最大值范围内迈出一小步,并绘制/填充数据。它有点像像素网格;我将大小缩小到只有 2000 个的数组,并注意到坐标只是来回的线(最终包含整个坐标平面)。

问题:我可以检索决策边界线本身的 x,y 坐标(而不是在整个平面上迭代)吗?理想情况下,列表仅包含每条线的转折点。或者,是否有其他完全不同的方式来重新创建该图,以便计算效率更高?

这可以通过替换来可视化contourf()打电话给countour():

enter image description here

我只是不确定如何检索管理这些线路的数据(通过xx, yy and Z或者可能有其他方式?)。

Note:我对包含行格式的列表/或数据结构的确切格式并不挑剔,只要其计算效率高即可。例如,对于上面的第一个图,一些红色区域实际上是预测空间中的岛屿,因此这可能意味着我们必须像处理自己的线一样处理它。我猜想只要类与 x,y 坐标相结合,使用多少个数组(包含坐标)来捕获决策边界就无关紧要。


决策树没有很好的边界。它们具有多个边界,将特征空间分层划分为矩形区域。

在我的实施中节点收获我编写了解析 scikit 决策树并提取决策区域的函数。对于这个答案,我修改了部分代码以返回与树决策区域相对应的矩形列表。使用任何绘图库应该很容易绘制这些矩形。这是一个使用 matplotlib 的示例:

n = 100
np.random.seed(42)
x = np.concatenate([np.random.randn(n, 2) + 1, np.random.randn(n, 2) - 1])
y = ['b'] * n + ['r'] * n
plt.scatter(x[:, 0], x[:, 1], c=y)

dtc = DecisionTreeClassifier().fit(x, y)
rectangles = decision_areas(dtc, [-3, 3, -3, 3])
plot_areas(rectangles)
plt.xlim(-3, 3)
plt.ylim(-3, 3)

enter image description here

不同颜色区域相遇的地方就有决策边界。我想通过适度的努力就有可能提取出这些边界线,但我将把它留给任何感兴趣的人。

rectangles是一个numpy数组。每一行对应一个矩形,列是[left, right, top, bottom, class].


更新:Iris 数据集的应用

Iris 数据集包含三个类,而不是示例中的 2 个类。所以我们必须添加另一种颜色plot_areas功能:color = ['b', 'r', 'g'][int(rect[4])]。 此外,数据集是 4 维的(包含四个特征),但我们只能在 2D 中绘制两个特征。我们需要选择绘制哪些特征并告诉decision_area功能。该函数有两个参数x and y- 这些是分别位于 x 轴和 y 轴上的特征。默认为x=0, y=1它适用于任何具有多个特征的数据集。然而,在 Iris 数据集中,第一个维度不是很有趣,因此我们将使用不同的设置。

功能decision_areas也不知道数据集的范围。通常,决策树具有向无穷延伸的开放决策范围(例如,每当萼片长度小于xyz这是B类)。在这种情况下,我们需要人为地缩小绘图范围。我选择了-3..3对于示例数据集,但对于 iris 数据集,其他范围是合适的(永远不会有负值,某些特征超出 3)。

这里我们绘制了最后两个特征在 0..7 和 0..5 范围内的决策区域:

from sklearn.datasets import load_iris
data = load_iris()
x = data.data
y = data.target
dtc = DecisionTreeClassifier().fit(x, y)
rectangles = decision_areas(dtc, [0, 7, 0, 5], x=2, y=3)
plt.scatter(x[:, 2], x[:, 3], c=y)
plot_areas(rectangles)

enter image description here

请注意左上角的红色和绿色区域有奇怪的重叠。发生这种情况是因为树在四个维度上做出决策,但我们只能显示两个维度。确实没有一个干净的方法来解决这个问题。高维分类器在低维空间中通常没有良好的决策边界。

因此,如果您对分类器更感兴趣,这就是您所得到的。您可以沿着各种尺寸组合生成不同的视图,但表示的有用性受到限制。

但是,如果您对数据比对分类器更感兴趣,则可以在拟合之前限制维数。在这种情况下,分类器仅在二维空间中做出决策,我们可以绘制漂亮的决策区域:

from sklearn.datasets import load_iris
data = load_iris()
x = data.data[:, [2, 3]]
y = data.target
dtc = DecisionTreeClassifier().fit(x, y)
rectangles = decision_areas(dtc, [0, 7, 0, 3], x=0, y=1)
plt.scatter(x[:, 0], x[:, 1], c=y)
plot_areas(rectangles)

enter image description here


最后,这是实现:

import numpy as np
from collections import deque
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import _tree as ctree
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle


class AABB:
    """Axis-aligned bounding box"""
    def __init__(self, n_features):
        self.limits = np.array([[-np.inf, np.inf]] * n_features)

    def split(self, f, v):
        left = AABB(self.limits.shape[0])
        right = AABB(self.limits.shape[0])
        left.limits = self.limits.copy()
        right.limits = self.limits.copy()

        left.limits[f, 1] = v
        right.limits[f, 0] = v

        return left, right


def tree_bounds(tree, n_features=None):
    """Compute final decision rule for each node in tree"""
    if n_features is None:
        n_features = np.max(tree.feature) + 1
    aabbs = [AABB(n_features) for _ in range(tree.node_count)]
    queue = deque([0])
    while queue:
        i = queue.pop()
        l = tree.children_left[i]
        r = tree.children_right[i]
        if l != ctree.TREE_LEAF:
            aabbs[l], aabbs[r] = aabbs[i].split(tree.feature[i], tree.threshold[i])
            queue.extend([l, r])
    return aabbs


def decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None):
    """ Extract decision areas.

    tree_classifier: Instance of a sklearn.tree.DecisionTreeClassifier
    maxrange: values to insert for [left, right, top, bottom] if the interval is open (+/-inf) 
    x: index of the feature that goes on the x axis
    y: index of the feature that goes on the y axis
    n_features: override autodetection of number of features
    """
    tree = tree_classifier.tree_
    aabbs = tree_bounds(tree, n_features)

    rectangles = []
    for i in range(len(aabbs)):
        if tree.children_left[i] != ctree.TREE_LEAF:
            continue
        l = aabbs[i].limits
        r = [l[x, 0], l[x, 1], l[y, 0], l[y, 1], np.argmax(tree.value[i])]
        rectangles.append(r)
    rectangles = np.array(rectangles)
    rectangles[:, [0, 2]] = np.maximum(rectangles[:, [0, 2]], maxrange[0::2])
    rectangles[:, [1, 3]] = np.minimum(rectangles[:, [1, 3]], maxrange[1::2])
    return rectangles

def plot_areas(rectangles):
    for rect in rectangles:
        color = ['b', 'r'][int(rect[4])]
        print(rect[0], rect[1], rect[2] - rect[0], rect[3] - rect[1])
        rp = Rectangle([rect[0], rect[2]], 
                       rect[1] - rect[0], 
                       rect[3] - rect[2], color=color, alpha=0.3)
        plt.gca().add_artist(rp)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

从 SKlearn 决策树中检索决策边界线(x,y 坐标格式) 的相关文章

随机推荐

  • flutter查询firestore中的多个集合

    我正在玩 flutter 但我遇到了 firestore 的问题 我无法弄清楚 假设我想检索客户的购买者历史记录 并且我有一个如下所述的 Firestore 因此我有一个 用户 集合 其中包含以下文档user id然后在其中 我有一个 产品
  • Identity 2.1 - 未找到 UserId 但之前可以使用

    该代码之前多次工作 但在 Identity 2 1 中为用户添加几个新属性后 它突然停止工作 尽管在调试器中可以看到 UserId 的值 但我收到了 UserId not found 错误 任何人都知道为什么会突然发生这种情况 至少看到这一
  • 窗口卸载事件的本地存储

    我使用本地存储来存储一些数据 用户发出 ajax 请求来获取信息 我将结果存储在存储中 以便下次他请求相同的信息时 我首先在存储中查看它是否存在 现在我意识到 将数据保存在内存中的对象中 并在需要时循环该数据 而不是循环本地存储 实际上更有
  • Delphi 的 WebSocket 客户端实现

    Delphi 有免费的 WebSocket 客户端实现吗 我只找到了这个 WebSockets Delphi 组件 但它不是免费的 这是我的开源库 https github com andremussche DelphiWebsockets
  • Django 2.1 - 'functools.partial' 对象没有属性 '__name__'

    我最近将 Django 从 2 0 7 升级到 2 1 1 出现了一个新错误 其中出现此错误 functools partial object has no attribute name 我想了解我的修复是否正确以及是什么导致了这个新错误的
  • 对核心数据实体进行排序的最佳方法是什么?

    我有一个完全正常工作的核心数据模型 但是当我使用获取请求返回数据时 它的顺序看似随机 对这些数据进行排序的最佳方法是什么 是使用核心数据模型中的另一个表 然后 查询 第一个表吗 或者是将数据拉入数组中 然后以这种方式排序 我不太确定如何做其
  • ggplot 函数在图例下方添加文本

    在 R 中 我想创建一个函数 它接受 ggplot 对象和一些文本并返回 ggplot 对象 方法是在图例下方添加文本 在图的右侧 同时将图例保留在右侧 myplot ggplot iris aes x Sepal Length y Sep
  • 单一来源项目结构有哪些缺点?

    我是目前公司的新人 正在从事由我的直接团队领导编写的项目 该公司通常不使用 C 但我的同事用 C C 编写了高效的代码 只有我们知道如何用 C 编码 我和我的领导 所以没有第三种意见可以涉及 在我对这个项目有了足够的了解之后 我意识到整个结
  • 什么是无界数组?

    什么是无界数组 无界数组和动态分配数组有什么区别 与无界数组相关的常见操作有哪些 就像我们有堆栈数据结构的弹出和推送 无界数组可以 并且通常是 静态分配 实现无界数组时的主要关注点是提供类似动态数组的自由来在运行时决定数组大小 而不会因运行
  • 与序列化相比,使用 MarshalByRefObject 的成本有多高?

    在我的 Azure Web 角色代码中 我有一个CustomIdentity类派生自System Security Principal IIdentity 在某些时候 NET 运行时尝试序列化该类 and 序列化不起作用 试图解决我搜索了很
  • 为什么 range-for 找不到 std::istream_iterator 的 begin 和 end 重载?

    我有这样的代码 std ifstream file filename std ios base in if file good file imbue std locale std locale new delimeter tokens fo
  • 扩展此类以在列表视图中撤消/重做

    我正在使用第三方代码来管理 Windows 窗体项目中的撤消 重做操作 我需要扩展该类来管理列表视图中的撤消 重做操作 这意味着 撤消 重做添加 删除项目和子项目 撤消 重做检查 取消检查行 撤消 重做一些我可能错过的其他重要事情 我不知道
  • 环回模型中的动态属性或聚合函数

    我将如何在环回模型中使用聚合函数 如果我有一个由 mysql 数据库支持的模型 我是否可以让 Model1 与 Model2 具有 hasMany 关系 具有给定的数字属性 并在 Model1 中拥有一个从 Model2 获取该字段的 SU
  • 如何从 php 中的字符串中获取确定数量的单词?

    这就是我正在尝试做的事情 我有一段文本 我想从字符串中提取前 50 个单词 而不切断中间的单词 这就是为什么我更喜欢单词而不是字符 然后我可以使用 left 函数 我知道 str word count var 函数将返回字符串中的单词数 但
  • PL/SQL中如何查看变量的类型?

    PL SQL 中是否有函数可以显示变量的确切类型 例如 SQL 中的 DUMP 函数 我尝试过以下方法 DECLARE l variable INTEGER 1 BEGIN DBMS OUTPUT PUT LINE DUMP l varia
  • 单击文件上传按钮后出现延迟?

    当我单击 选择要上传的文件 即输入类型 文件 时 从单击按钮和选择文件到在按钮旁边显示所选文件之间存在延迟 浏览器是否正在尝试将文件加载到浏览器中 为什么有延迟 接下来 我如何显示 请稍候 消息立即地选择文件后 我尝试了各种 JQ 选项 似
  • javascript 如何判断对象是否存在于数组中

    我有一个 JavaScript 对象数组 与此类似的东西 var objectArray Name A Id 1 Name B Id 2 Name C Id 3 Name D Id 4 现在我试图找出一个对象是否具有给定的属性Name值存在
  • 如何在 jQuery 中选择“this”内的元素?

    我知道我可以这样选择一个元素 ul topnav gt li target css border 3px double red 但我该怎么做 this gt li target css border 3px double red this
  • 将图像存储到 Access 数据库的附件字段中

    我正在编写一个 VB 应用程序 需要在数据库中存储图像 用户在计算机上选择图像 这会以字符串形式提供路径 这是我的尝试 但是我收到错误 INSERT INTO 查询不能包含多值字段 这是我的代码 Dim buff As Byte Nothi
  • 从 SKlearn 决策树中检索决策边界线(x,y 坐标格式)

    我正在尝试在外部可视化平台上创建曲面图 我正在使用 iris 数据集sklearn 决策树文档页面 我还使用相同的方法来创建决策曲面图 但我的最终目标不是 matplot lib 视觉效果 因此从这里我将数据输入到我的可视化软件中 为此 我