我训练过一个梯度提升分类器 http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.GradientBoostingClassifier.html#sklearn.ensemble.GradientBoostingClassifier,我想使用所示的 graphviz_exporter 工具将其可视化here http://scikit-learn.org/stable/modules/tree.html.
当我尝试时我得到:
AttributeError: 'GradientBoostingClassifier' object has no attribute 'tree_'
这是因为 graphviz_exporter 的目的是决策树 http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier,但我想仍然有一种方法可以将其可视化,因为梯度增强分类器必须有一个底层决策树。
怎么做?
属性估计器包含底层决策树。以下代码显示经过训练的 GradientBoostingClassifier 的树之一。请注意,虽然整体是一个分类器,但每个单独的树都会计算浮点值。
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.tree import export_graphviz
import numpy as np
# Ficticuous data
np.random.seed(0)
X = np.random.normal(0,1,(1000, 3))
y = X[:,0]+X[:,1]*X[:,2] > 0
# Classifier
clf = GradientBoostingClassifier(max_depth=3, random_state=0)
clf.fit(X[:600], y[:600])
# Get the tree number 42
sub_tree_42 = clf.estimators_[42, 0]
# Visualization
# Install graphviz: https://www.graphviz.org/download/
from pydotplus import graph_from_dot_data
from IPython.display import Image
dot_data = export_graphviz(
sub_tree_42,
out_file=None, filled=True, rounded=True,
special_characters=True,
proportion=False, impurity=False, # enable them if you want
)
graph = graph_from_dot_data(dot_data)
png = graph.create_png()
# Save (optional)
from pathlib import Path
Path('./out.png').write_bytes(png)
# Display
Image(png)
42号树:
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)