sklearn决策树怎么使用ccp_alpha进行剪枝

2023-11-18

本站原创文章,转载请说明来自《老饼讲解-机器学习》ml.bbbdata.com


目录

一.CCP后剪枝是什么  

二.如何通过ccp_alpha进行后剪枝

(1) 查看CCP路径  

(2)根据CCP路径剪树   

三、完整CCP剪枝应用实操DEMO

四、CCP路径是怎么计算出来的


本文讲解sklearn中决策树ccp_alpha参数的使用方法,它主要用于ccp后剪枝。

ccp_alpha的值要设为多少?
事实上,一般并不是直接设置ccp_alpha,
而是要先打印CCP路径,再根据路径信息来决定alpha的值。
CCP路径又是什么鬼?
本文一一道来。

 

一.CCP后剪枝是什么  


后剪枝一般指的是CCP代价复杂度剪枝法(Cost Complexity Pruning),
即在树构建完成后,对树进行剪枝简化,使以下损失函数最小化
L = \displaystyle \sum \limits _{i=1}^{T} \frac{N_i}{N} L_i +\alpha T
T :叶子节点个数                     
N :所有样本个数                      
N_i:第 i 个叶子节点上的样本数 
L_i: 第i个叶子节点的损失函数  
\alpha:待定系数,用于惩罚节点个数,引导模型用更少的节点。   

损失函数既考虑了代价,又考虑了树的复杂度,所以叫代价复杂度剪枝法,
实质就是在树的复杂度与准确性之间取得一个平衡点。


备注:在sklearn中,如果criterion设为GINI,Li 则是每个叶子节点的GINI系数,如果设为entropy,则是熵。


 


二.如何通过ccp_alpha进行后剪枝


具体操作过程如下:

(1) 查看CCP路径  

计算CCP路径,查看alpha与树质量的关系:
构建好树后,我们可以通过clf.cost_complexity_pruning_path(X, y) 查看树的CCP路径,
Demo代码如下:

# -*- coding: utf-8 -*-
from sklearn.datasets import load_iris
from sklearn import tree
import numpy as np
#----------------数据准备----------------------------
iris = load_iris()                          # 加载数据
X = iris.data
y = iris.target
#---------------模型训练---------------------------------
clf = tree.DecisionTreeClassifier(min_samples_split=10,ccp_alpha=0)        
clf = clf.fit(X, y)     
#-------计算ccp路径-----------------------
pruning_path = clf.cost_complexity_pruning_path(X, y)
#-------打印结果---------------------------    
print("\n====CCP路径=================")
print("ccp_alphas:",pruning_path['ccp_alphas'])
print("impurities:",pruning_path['impurities']) 


运行结果:

====sklearn的CCP路径=================
ccp_alphas: [0.      0.00415459 0.01305556 0.02966049 0.25979603 0.33333333]
impurities: [0.02666667 0.03082126 0.04387681 0.07353731 0.33333333 0.66666667]

它的意思是:

0<α<0.00415时,树的不纯度为 0.02666,
0.00415< α<0.01305时,树的不纯度为 0.03082,
0.01305<α<0.02966时,树的不纯度为 0.04387,
........
其中,树的不纯度指的是损失函数的前部分L = \displaystyle \sum \limits _{i=1}^{T} \frac{N_i}{N} L_i, 

也即所有叶子的不纯度(gini或者熵)加权和.

小贴士  
ccp_path只提供树的不纯度,
如果还需要alpha对应的其它信息,
则可以将alpha代入模型中训练,
从训练好的模型中获取。 

(2)根据CCP路径剪树   

根据树的质量,选定alpha进行剪树
 我们根据业务实际情况,选择一个可以接受的树不纯度,
找到对应的alpha,
例如,我们可接受的树不纯度为0.0735,
则alpha可设为0.1(在0.02966与0.25979之间)
对模型重新以参数ccp_alpha=0.1进行训练,
即可得到剪枝后的决策树。


三、完整CCP剪枝应用实操DEMO

完整代码如下:

 # -*- coding: utf-8 -*-
from sklearn.datasets import load_iris
from sklearn import tree
import numpy as np

#--------数据准备-----------------------------------
iris = load_iris()                          # 加载数据
X = iris.data
y = iris.target
#-------模型训练---------------------------------
clf = tree.DecisionTreeClassifier(min_samples_split=10,random_state=0,ccp_alpha=0)        
clf = clf.fit(X, y)     
#-------计算ccp路径------------------------------
pruning_path = clf.cost_complexity_pruning_path(X, y)

#-------打印结果---------------------------------   
print("\n====CCP路径=================")
print("ccp_alphas:",pruning_path['ccp_alphas'])
print("impurities:",pruning_path['impurities'])    

#------设置alpha对树后剪枝-----------------------
clf = tree.DecisionTreeClassifier(min_samples_split=10,random_state=0,ccp_alpha=0.1)        
clf = clf.fit(X, y) 
#------自行计算树纯度以验证-----------------------
is_leaf =clf.tree_.children_left ==-1
tree_impurities = (clf.tree_.impurity[is_leaf]* clf.tree_.n_node_samples[is_leaf]/len(y)).sum()
#-------打印结果--------------------------- 
print("\n==设置alpha=0.1剪枝后的树纯度:=========\n",tree_impurities)


运行结果:

====CCP路径=================
ccp_alphas: [0.      0.00415459 0.01305556 0.02966049 0.25979603 0.33333333]
impurities: [0.02666667 0.03082126 0.04387681 0.07353731 0.33333333 0.66666667]

==设置alpha=0.1剪枝后的树纯度:=========
 0.0735373054213634


四、CCP路径是怎么计算出来的

对于CCP路径的计算过程,本文不再重复讲解,可参考:
1.《决策树后剪枝原理:CCP剪枝法》                      
2.《决策树(sklearn)中CCP路径计算的实现方式.py》


相关文章

《入门篇-环境搭建:anaconda安装》

《​​​​​​入门篇-模型:逻辑回归》

《入门篇-模型:决策树-CART》

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

sklearn决策树怎么使用ccp_alpha进行剪枝 的相关文章

  • Java基础面试题

    1 面向对象的特征有哪些方面 1 抽象 抽象就是忽略一个主题中与当前目标无关的那些方面 以便更充分地注意与当前目标有关的方面 抽象并不打算了解全部问题 而只是选择其中的一部分 暂时不用部分细节 抽象包括两个方面 一是过程抽象 二是数据抽象

随机推荐

  • 十分钟让你明白Objective-C的语法(和Java、C++的对比)

    很多想开发iOS 或者正在开发iOS的程序员以前都做过Java或者C 当第一次看到Objective C的代码时都会头疼 Objective C的代码在语法上和Java C 有着很大的区别 有的同学会感觉像是看天书一样 不过 语言都是相通的
  • smp和mpp计算机

    SMP 是Symmetric Multi Processing的简称 意为对称多处理系统 内有许多紧耦合多处理器 这种系统的最 大特点就是共享所有资源 MPP 另外与之相对立的标准是MPP Massively Parallel Proces
  • Linux驱动

    Linux驱动入门系列 Linux驱动入门 一 字符设备驱动基础 Linux驱动入门 二 操作硬件 Linux驱动入门 三 Led驱动 Linux驱动入门 四 非阻塞方式实现按键驱动 Linux驱动入门 五 阻塞方式实现按键驱动 Linux
  • ​7.1 项目1 学生通讯录管理:文本文件增删改查(C++版本)(自顶向下设计+断点调试) (A)​

    C 自学精简教程 目录 必读 作业目标 这个作业中 你需要综合运用之前文章中的知识 来解决一个相对完整的应用程序 作业描述 1 在这个作业中你需要在文本文件中存储学生通讯录的信息 并在程序启动的时候加载这些数据到内存中 2 在程序运行过程中
  • 用Python绘制六种可视化图表,简直太好用了

    前言 本文的文字及图片来源于网络 仅供学习 交流使用 不具有任何商业用途 如有问题请及时联系我们以作处理 PS 如有需要Python学习资料的小伙伴可以加点击下方链接自行获取 python免费学习资料以及群交流解答点击即可加入 可视化图表
  • 邻接矩阵实现的带权有向图(C++)

    邻接矩阵实现的带权有向图 C 相关概念 定义和声明 实现 1 距离无穷大的定义 2 构造函数 3 深度优先遍历 4 广度优先遍历 6 将邻接矩阵转换为邻接表 7 重载 lt lt 运算符 打印输出 测试 测试代码 测试结果 源代码 相关概念
  • Callable 接口实现java 的多线程

    java 中创建多线程最常见的是继承Thread 的子类重写run 方法 还有就是实现Runnable 接口 我们最好使用实现了Runnable 接口的方法原因有两点 因为java 的单继承的特点 所以说使用第一种方法不能继承其他父类了 采
  • Lunix历史及如何学习

    1 Lunix是什么 1 1 Lunix是操作系统还是应用程序 Lunix是一套操作系统 它提供了一个完整的操作系统当中最底层的硬件控制与资源管理的完整架构 这个架构是沿袭Unix 良好的传统来的 所以相当的稳定而功能强大 Lunix具有核
  • SCI论文润色插件Product Content Checker扩展程序

    下载地址 https www gugeapps net webstore detail product content checker ilmaafbmfcklldgoehebccigadbkbdpc download 打开方式 直接将下载
  • simhash算法原理及实现

    一篇不错的介绍simhash的文章 如下 http blog csdn net chenguolinblog article details 50830948
  • 多个ajax请求时控制执行顺序或者等待执行完成后的操作

    当确保执行顺序时 一 请求加async false 这样所有的ajax就会同步执行 请求顺序就是代码顺序 代码部分 when ajax async false url url1 ajax async false url url2 done
  • ai绘画小程序基于novelai的tag列表源码展示(独家)

    视频 哔哩哔哩 看视频 介绍 一个tag列表展示
  • 代码行统计工具_cloc

    下载并运行 在Github下载稳定发布版本 Releases AlDanial cloc GitHub 直接下载exe文件 放在需要统计代码的文件夹下 用cmd或是powershell运行 cloc 1 96 exe 注意 之前有个空格 c
  • hive 错误 InvalidObjectException(message:Role admin already exists.)

    InvalidObjectException message Role admin already exists at org apache hadoop hive metastore ObjectStore addRole ObjectS
  • python去掉列表中的单引号_从Python中的列表中删除单引号

    我有一个输入字符串 result testing 0 8841 642000 0 80 014521 60 940653 4522126666 1500854400 1500842014000 name 80 014521 60 99653
  • C语言实现顺序表

    线性表是数据结构中的逻辑结构 线性表采用顺序存储的方式存储就称之为顺序表 数组是顺序表在实际编程中的具体实现方式之一 本篇主要介绍顺序表 顺序表的创建 添加元素 删除元素 遍历输出等操作 1 创建顺序表 1 1定义顺序表结构体 结构体包含三
  • Fisco Bcos区块链一(搭建单群组FISCO BCOS联盟链)

    文章目录 区块链开荒 技术文档 https fisco bcos documentation readthedocs io zh CN latest index html 一 搭建第一个区块链网络 1 搭建单群组FISCO BCOS联盟链
  • java基础语法

    java基础语法 1 Java概述 1 1 Java语言发展史 了解 1 2 Java语言跨平台原理 理解 1 3 JRE和JDK 记忆 1 4 JDK的下载和安装 应用 1 4 1 下载 1 4 2 安装 1 4 3 JDK的安装目录介绍
  • python进阶之多线程对同一个全局变量的处理

    通常情况下 from threading import Thread global num 0 def func1 global global num for i in range 1000000 global num 1 print fu
  • sklearn决策树怎么使用ccp_alpha进行剪枝

    本站原创文章 转载请说明来自 老饼讲解 机器学习 ml bbbdata com 目录 一 CCP后剪枝是什么 二 如何通过ccp alpha进行后剪枝 1 查看CCP路径 2 根据CCP路径剪树 三 完整CCP剪枝应用实操DEMO 四 CC