机器学习--决策树

2023-11-11

一、决策树简介

决策树(DecisionTree),又称为判定树,是另一种特殊的根树,它最初是运筹学中的常用工具之一;之后应用范围不断扩展,目前是人工智能中常见的机器学习方法之一。决策树是一种基于树结构来进行决策的分类算法,我们希望从给定的训练数据集学得一个模型(即决策树),用该模型对新样本分类。决策树可以非常直观展现分类的过程和结果,决策树模型构建成功后,对样本的分类效率也非常高。

二、决策树的优缺点

优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。
缺点:可能会产生过度匹配问题。

三、决策树的一般流程

(1)收集数据:可以使用如何方法。
(2)准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化。
(3)分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期。
(4)训练算法:构造树的数据结构。
(5)测试算法:使用经验树计算错误率。
(6)使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在          含义。

四、信息增益

样本有多个属性,该先选哪个样本来划分数据集呢?在划分数据集之前之后信息发生的变化称为信息增益,获得信息增益最高的属性就是最好的选择。

1.信息熵

样本集合D中第k类样本所占的比例p_k(k=1,2,…,|K|),|K|为样本分类的个数,则D的信息熵为:

Ent(D)的值越小,则D的纯度越高。换句话说,信息熵越小,信息增益越大。

2.信息增益

使用属性a对样本集D进行划分所获得的“信息增益”的计算方法是,用样本集的总信息熵

减去属性a的每个分支的信息熵与权重(该分支的样本数除以总样本数)的乘积。

通常,信息增益越大,意味着用属性a进行划分所获得的“纯度提升”越大。我们的目标就是寻找使信息增益最大的属性作为划分的依据。

五、决策树的具体实现

1.收集数据

我收集的数据依旧是集美大学计算机工程学院acm比赛校选的数据,其中每列的属性分别是成绩、用时、年级、奖项。

  

2.准备数据

由于我所用的数据很明显是连续型的,我们需要将数据离散化。这里我随机挑选15个数据进行离散化。 其中将成绩、用时、年级分为4个等级,数字越大,分别代表成绩越高,用时越长,年级越高。奖项分为3个等级,1等奖,2等奖,3等奖。

3.导入数据

用pandas模块的read_csv()函数读取数据文本,分别求出数据集data,标签集labels,所有情况集labels_full

#导入数据
def import_data():
    data = pd.read_csv('data1.txt')
    data.head(10)
    data=np.array(data).tolist()
    # 属性值列表
    labels = ['得分', '用时', '年级', '奖项']

    # 特征对应的所有可能的情况
    labels_full = {}

    for i in range(len(labels)):
        labelList = [example[i] for example in data] #获取每一行的第一个数
        uniqueLabel = set(labelList)#去重
        labels_full[labels[i]] = uniqueLabel#每一个属性所对应的种类
    return data,labels,labels_full

data,labels,labels_full=import_data()

4.计算信息熵

编写计算信息熵的算法,为后面的计算信息增益打下基础

#计算信息熵
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)#计算数据集总数
    labelCounts = collections.defaultdict(int)#用来统计标签
    for featVec in dataSet:
        currentLabel = featVec[-1]#得到数据的分类标签
        if currentLabel not in labelCounts.keys():#若当前的标签不在标签集中则创建一个
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1 #标签集中对应标签数目加一,统计每个类别

    shannonEnt = 0.0 #信息熵初始值
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries #pk
        shannonEnt -= prob * math.log2(prob)
    return shannonEnt 

#print("当前数据的总信息熵",calcShannonEnt(data))

计算出该数据集的总信息熵

 

5.划分数据集

 我们对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集是最好的划分方法。

#划分数据集
def splitDataSet(dataSet, axis, value):# 待划分的数据集 划分数据集的特征 需要返回的特征的值
    retDataSet = [] #创建一个新的列表
    for featVec in dataSet:
        if featVec[axis]==value:#如果给定的特征值是等于想要的特征值
            #将该特征值前面的内容保存起来
            reducedFeatVec = featVec[:axis] 
            #将该特征值后面的内容保存起来
            reducedFeatVec.extend(featVec[axis + 1:])
            #表示去掉在axis中特征值为value的样本后而得到的数据集
            retDataSet.append(reducedFeatVec)
    return retDataSet 

通过计算每个特征的信息增益,我们的目标是找出信息增益最大的的数据集划分方式,这个划分方式就是最好的数据集划分方式。

#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet, labels):
    #得到数据的特征值总数
    numFeatures = len(dataSet[0]) - 1
    #计算出总信息熵
    baseEntropy = calcShannonEnt(dataSet)
    #基础信息增益为0.0
    bestInfoGain = 0.0
    #最好的特征值
    bestFeature = -1
    #对每个特征值进行求信息熵
    for i in range(numFeatures):
        #得到数据集中所有的当前特征值列表
        featList = [example[i] for example in dataSet]
        #去掉重复的
        uniqueVals = set(featList)
        #新的熵,代表当前特征值的熵
        newEntropy = 0.0
        #遍历现在有的特征的可能性
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet=dataSet, axis=i, value=value)#在全部数据集的当前特征位置上,找到该特征值等于当前值的集合
            prob = len(subDataSet) / float(len(dataSet))#计算权重
            newEntropy += prob * calcShannonEnt(subDataSet)#计算当前特征值的熵
        infoGain = baseEntropy - newEntropy#计算信息增益
        print('当前特征值为:' + labels[i] + ',对应的信息增益值为:' + str(infoGain)+"i等于"+str(i))
        #选出最大的信息增益
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i #新的最好的用来划分的特征值
    print('信息增益最大的特征为:' + labels[bestFeature])
    return bestFeature

#print(chooseBestFeatureToSplit(data,labels))

我们发现信息增益的最大值是得分,其次是用时,最后是年级。

6.递归构建决策树

在构建决策树,可能会出现这一种情况,如果数据集已经处理了所有的属性,但是类标签依然不是唯一的。在这种情况下,我们通常会采用多数表决的方法决定叶子节点的分类。

#投票分类
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys():classCount[vote]=0
        classCount[vote]+=1
    sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
    print(sortedClassCount)
    return sortedClassCount[0][0] #返回出现次数最多的分类

对树进行创建

#创建树
def createTree(dataSet,labels):
    #拿到所有数据的分类标签
    classList=[example[-1] for example in dataSet] 
    #类别完全相同则停止继续划分
    if classList.count(classList[0])==len(classList):
        return classList[0]
    #遍历完所有特征时返回出现次数最多的类别
    if len(dataSet[0])==1:
        return majorityCnt(classList)
    bestFeat=chooseBestFeatureToSplit(dataSet,labels)#选择最好的划分特征,得到该特征的下标
    print(bestFeat)
    bestFeatLabel=labels[bestFeat]#得到最好特征的名称
    print(bestFeatLabel)
    #使用一个字典来存储树结构,分叉处为划分的特征名称
    myTree={bestFeatLabel:{}} 
    del(labels[bestFeat])#删除本次划分的特征值
    featValues=[example[bestFeat] for example in dataSet ]
    uniqueVals=set(featValues)
    for value in uniqueVals:
        #得到剩下的特征值
        subLabels=labels[:] 
        #递归调用
        myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
    return myTree

#print(createTree(data,labels))

我们得到一个字典类型的树。
 

7.使用Matplotlib注解绘制树形图

将树的结构可视化,有助于我们理解具体的分类过程。

import matplotlib.pyplot as plt
import matplotlib


# 能够显示中文
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['font.serif'] = ['SimHei']
#定义文本框和箭头格式
decisionNode=dict(boxstyle="sawtooth",fc='0.8') #分叉节点
leafNode=dict(boxstyle="round4",fc='0.8') #叶子节点
arrow_args=dict(arrowstyle="<-")

def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    createPlot.axl.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',
    xytext=centerPt,
    textcoords='axes fraction',
    va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)

#获取叶节点的数目和树的层数
def getNumLeafs(myTree):
    numLeafs=0#统计叶子节点总数
    firstStr=list(myTree.keys())[0]#得到根节点
    secondDict=myTree[firstStr]#第一个节点对应的内容
    for key in secondDict.keys(): #如果key对应的是一个字典,就递归调用
        if type(secondDict[key]).__name__=='dict':
            numLeafs+=getNumLeafs(secondDict[key])
        else:
            numLeafs+=1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth=0
    firstStr=list(myTree.keys())[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth=1+getTreeDepth(secondDict[key])
        else:
            thisDepth=1
        if thisDepth>maxDepth:maxDepth=thisDepth
    return maxDepth


#计算出父节点和子节点的中间位置,填充信息
def plotMidText(cntrPt,parentPt,txtString):
    xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
    yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
    createPlot.axl.text(xMid,yMid,txtString)
#绘制出树的所有节点,递归绘制
def plotTree(mytree,parentPt,nodeTxt):
    numLeafs=getNumLeafs(mytree)
    depth=getTreeDepth(mytree)
    firstStr=list(mytree.keys())[0]
    cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict=mytree[firstStr]
    plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD

#绘制决策树
def createPlot(inTree):
    fig=plt.figure(1,facecolor='white')
    fig.clf()
    axprops=dict(xticks=[],yticks=[])
    createPlot.axl=plt.subplot(111,frameon=False,**axprops)
    plotTree.totalW=float(getNumLeafs(inTree))
    plotTree.totalD=float(getTreeDepth(inTree))
    plotTree.xOff=-0.5/plotTree.totalW;plotTree.yOff=1.0;plotTree(inTree,(0.5,1.0),'')
    plt.show()


展示结果

if __name__ == '__main__':
    mytree=createTree(data,labels)
    createPlot(mytree)

六.实验总结

本次实验只是对决策树的创建原理和创建算法以及展示创建的决策树进行了主要介绍。下一次实验,我们将会具体涉及到树的预剪枝、后剪枝、连续数据的离散化。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

机器学习--决策树 的相关文章

  • Vue基础精讲 —— 实例解析Vue的生命周期

    结合官网Vue生命周期图例 实例生命周期钩子 vue生命周期 import Vue from vue new Vue el root template div text div data text text beforeCreate 无法做
  • C语言计算机二级/C语言期末考试 刷题(四)

    收集了一些经典C语言计算机二级和C语言期末考试题库 整理不易 大家点赞收藏支持一下 祝大家计算机二级和期末考试都高分过 系列文章 C语言计算机二级 C语言期末考试 刷题 一 C语言计算机二级 C语言期末考试 刷题 二 C语言计算机二级 C语

随机推荐

  • dedecms织梦系统基本参数添加内容后首页显示后台空白不显示

    最近跟版网的小编在测试织梦模板时候发现 添加首页名称 版权信息 备案号等 保存后首页能显示 但是后台却没有文字 更新缓存也没用 那么这种情况该如何解决呢 小编告诉您 找到 dede templets sys info htm这个文件 里边找
  • syslog协议介绍

    syslog协议介绍 syslog架构 Unix Linux系统中的大部分日志都是通过一种叫做syslog的机制产生和维护的 syslog是一种标准的协议 分为客户端和服务器端 客户端是产生日志消息的一方 而服务器端负责接收客户端发送来的日
  • python unicodedecodeerror utf8_python问题,我运用python做中文词频分析的时候总是显示UnicodeDecodeError: 'utf-8'问题?...

    以下是我在python3 7idle中写的的语句importjiebatxt open E study pythondata 应用资料 三国演义 txt r encoding utf 8 read words jieba lcut txt
  • “Dependency ‘com.mysql:mysql-connector-j:‘ not found “等无法找到依赖问题解决

    在创建新的springboot项目时如果碰到 说明在该新建的项目中没有导入下列依赖 本人解决步骤 1 新建一个Maven工程 2 在该工程中加入自己想创建的springboot模块 3 将爆红的依赖复制粘贴进Maven项目中的pom xml
  • Mybatis查询where条件报 java.lang.IllegalStateException: range unbounded on this side解决方案

    Mybatis查询where条件报 java lang IllegalStateException range unbounded on this side解决方案 问题背景 解决方案 Lyric 就算是我不懂 问题背景 在使用id进行条件
  • docker-compose的使用

    一 docker compose命令 docker compose的使用非常类似于docker命令的使用 但是需要注意的是大部分的compose命令都需要到docker compose yml文件所在的目录下才能执行 docker comp
  • 数学建模笔记(三):数据预处理

    文章目录 前言 一 数据清洗 1 1 缺失值处理 1 2 异常值处理 二 数据变换 2 1 线性变换 2 2 向量规范化 2 3 min max归一化 2 4 z score标准化 三 数据预处理案例及代码实现 3 1 线性变换 代码实现
  • 【算法】算法学习四:图

    文章目录 一 什么是图 二 广度优先搜索 三 什么是队列 四 广度优先搜索的实现 4 1 实现全部的代码 4 2 队列的实现 五 深度优先搜索 六 图的运行时间 6 1 广度优先搜索 6 2 深度优先搜索 一 什么是图 在计算机科学中 图
  • Python静态方法和类方法的区别和应用(无师自通)

    实际上 Python 完全支持定义类方法 甚至支持定义静态方法 Python 的类方法和静态方法很相似 它们都推荐使用类来调用 其实也可使用对象来调用 类方法和静态方法的区别在于 Python会自动绑定类方法的第一个参数 类方法的第一个参数
  • 禅道后台命令执行漏洞二

    漏洞简介 禅道是第一款国产的开源项目管理软件 它集产品管理 项目管理 质量管理 文档管理 组织管理和事务管理于一体 是一款专业的研发项目管理软件 完整地覆盖了项目管理的核心流程 禅道管理思想注重实效 功能完备丰富 操作简洁高效 界面美观大方
  • vi编辑器的使用

    一 实验目的 理解vi的的三种运行模式及其切换方法 学会使用vi的各种操作命令进行文本文件编辑 用vi编写Linux下C程序 会用gcc编译 二 实验环境 一台装有Linux的机器 系统里面有gcc编译 三 实验内容 1 不保存直接退出 1
  • tomcat配置域名及默认访问页面

    1 配置80端口 在tomcat的conf server xml文件中的
  • 三子棋的实现--二维数组的应用

    通过对数组 函数 循环知识的应用我们可以独立地创建一个项目 三子棋 首先我们对于三子棋的实现要有一个大概的思路和逻辑 文件的创建 工欲善其事必先利其器 为了更好地完成项目 先创建三个文件 两个源文件 一个头文件 测试文件 test c 游戏
  • 使用druid-spring-boot-starter配合sharding报错

    在使用springboot时 为了方便配置 一般会使用启动器 不用单独进行 Bean 今天在增加sharding时一直出现找不到mapper的异常 对mapper加注解 加扫描包都不行 后来将druid spring boot starte
  • 安装指定版本nodejs

    要在Linux上安装指定版本的Node js 您可以使用Node Version Manager NVM NVM是一个用于管理多个Node js版本的工具 它允许您在同一系统上安装和切换不同的Node js版本 以下是使用NVM在Linux
  • 数字电源核心理论-“伏妙平衡“与“安秒平衡“

    数字电源 数字电源核心理论 伏妙平衡 与 安秒平衡 最后一个bug 2020 10 14 22 54 16 341 收藏 3 文章标签 编程语言 xhtml xmpp jrebel dwr 版权 1 聊一聊 今天跟大家分享的是迈克在本公众号
  • es批量增删改

    批量增删改 bulk 操作将文档的增删改查一系列操作 通过以此请求全部做完 减少网络传输次数 POST bulk 注意 bulk操作的形式是多个json 每个json写完必须换行 而在json内则不可以换行 多个json之间操作互不影响 即
  • C++标准模板库(STL)

    C 标准模板库 STL vector Introduction vector 长度根据需要而自动改变的数组 定义 vector
  • Oracle数据库插入大量数据

    insert into table name select rownum from dual connect by level lt 100 以上命令向表中插入了数列1 2 3 100
  • 机器学习--决策树

    一 决策树简介 决策树 DecisionTree 又称为判定树 是另一种特殊的根树 它最初是运筹学中的常用工具之一 之后应用范围不断扩展 目前是人工智能中常见的机器学习方法之一 决策树是一种基于树结构来进行决策的分类算法 我们希望从给定的训