本章介绍的决策树算法为ID3算法(Iterative Dichotomiser 3,迭代二叉树3代),主要流程为:根据信息增益找到划分数据的最佳特征——判断划分后每个数据子集是否为同一分类——若是,返回分类结果;若不是,再次划分数据子集(递归)。同时,本章利用Matplotlib的注解功能,将决策树可视化;利用pickle存储和读取决策树结构。
一、构造决策树
tree.py
1.计算信息熵
# 构造一个简单的“鱼鉴定”数据集
def createDataSet():
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing', 'flippers']
return dataSet, labels # 2个特征(标签)
# 测试结果
>>> import tree
>>> myDat,label=tree.createDataSet()
>>> myDat
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
# 计算给定数据集的香农熵
# first step:创建字典labelCounts={标签:标签出现次数}
# second step:根据公式求香农熵
def calcShannonEnt(dataSet):
numEntries = len(dataSet) # dataSet的个数
labelCounts = {} # 创建一个{标签:标签出现次数}的字典
for featVec in dataSet:
currentLabel = featVec[-1] # dataSet最后一列,即标签
if currentLabel not in list(labelCounts.keys()): # Python 3.x中,dict.keys()返回值不是list
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1 # 统计每个标签出现次数
shannonEnt = 0.0
for key in labelCounts: # 计算香农熵
prob = float(labelCounts[key])/numEntries # 每个类别标签出现的概率
shannonEnt -= prob * log(prob, 2) # 以2为底求对数
return shannonEnt
# 测试结果
>>> tree.calcShannonEnt(myDat)
0.9709505944546686
# 修改数据集内容
>>> myDat[0][-1]='maybe'
>>> myDat
[[1, 1, 'maybe'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
>>> tree.calcShannonEnt(myDat)
1.3709505944546687
信息(information)
如果待分类的事务可能划分在多个分类之中,则符号
x
i
x_i
xi的信息定义为:
l
(
x
i
)
=
−
l
o
g
2
p
(
x
i
)
l(x_i)=-log_2p(x_i)
l(xi)=−log2p(xi)其中,
p
(
x
i
)
p(x_i)
p(xi)为选择改分类的概率。可见某事务概率越高,其所包含的信息越少。
熵(entropy)
信息的期望值定义为熵,即所有类别的所有可能值所包含的信息的期望:
H
=
−
∑
l
=
1
n
p
(
x
i
)
l
o
g
2
p
(
x
i
)
H=-\sum_{l=1}^{n} p(x_i)log_2p(x_i)
H=−l=1∑np(xi)log2p(xi)其中,
n
n
n是分类的数目。
信息增益(information gain)
通俗理解,信息增益即某项特征能为分类系统,增加多少信息。信息增益越大,该特征越‘重要’。
dict.keys()
dict.keys()
返回一个字典dict的所有键。
Python 3.x中,如果直接使用dict.keys(),那么返回值为dict_keys,并非直接的列表,若要返回列表值还需调用list函数。
>>> dict = {'Name': 'Zara', 'Age': 7}
>>> dict.keys()
dict_keys(['Name', 'Age'])
>>> list(dict.keys())
['Name', 'Age']
‘/’ 和 ‘//’ 的区别
实验环境:python 3.7
/ :float除法。无论除数、被除数是类型,结果都为float,且为精确结果。
//:整除。又称地板除,即舍弃小数部分向下取整(注意负数),除数、被除数任一为float,结果即为float。
>>> 3/2
1.5
>>> 3//2
1
>>> 3/2.0
1.5
>>> 3//2.0
1.0
>>> -3//2 # 向下取整
-2
2.划分数据集
# 按照给定特征划分数据集
# first step:创建新的list
# second step:判断满足条件的示例,抽取除划分特征外剩余的list内容
# third step:将抽取的list append到retDataSet中,形成新的list
def splitDataSet(dataSet, axis, value): # 给定(划分数据集,划分特征所在列,划分特征值)
retDataSet = [] # 划分后的新数据集
for featVec in dataSet:
if featVec[axis] == value: # 划分特征值等于设定值
reducedFeatVec = featVec[:axis] # 截取划分特征所在列前面的特征
reducedFeatVec.extend(featVec[axis+1:]) # 添加划分特征所在列后面的特征(extend函数会将所有满足条件的特征列作为新的元素加入列表中)
retDataSet.append(reducedFeatVec) # 将重新‘组合’的列表整个加入retDataSet中
return retDataSet
此代码段中用到extend()函数和append()函数两个扩展函数,二者有所区别,参考。
# 测试结果
>>> myDat,label=tree.createDataSet()
>>> myDat
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
>>> tree.splitDataSet(myDat,0,1) #按myDat数据集中,第0个特征值为1划分
[[1, 'yes'], [1, 'yes'], [0, 'no']] #此三示例第0个特征值为1
>>> tree.splitDataSet(myDat,0,0)
[[1, 'no'], [1, 'no']]
set()
parame = {value01,value02,...}
set(value)
# 当创建一个空集合时,只能用第二种方式,用第一种{ }创建的是一个空字典
集合(set)是一个无序的不重复元素序列。
# 1.去重功能
>>>basket = {'apple', 'orange', 'apple', 'pear', 'orange', 'banana'}
>>> print(basket) # 这里演示的是去重功能
{'orange', 'banana', 'pear', 'apple'}
>>> 'orange' in basket # 快速判断元素是否在集合内
True
>>> 'crabgrass' in basket
False
# 2.两个set间的操作
>> a = set('abracadabra')
>>> b = set('alacazam')
>>> a
{'a', 'r', 'b', 'c', 'd'}
>>> a - b # 集合a中包含而集合b中不包含的元素
{'r', 'd', 'b'}
>>> a | b # 集合a或b中包含的所有元素(或)
{'a', 'c', 'r', 'd', 'b', 'm', 'z', 'l'}
>>> a & b # 集合a和b中都包含了的元素(与)
{'a', 'c'}
>>> a ^ b # 不同时包含于a和b的元素(异或)
{'r', 'd', 'b', 'm', 'z', 'l'}
# 选择最好的数据集划分方式,根据最大信息增益的原则
# first step:创建分类特征列表uniqueVals
# second step:计算每种划分方式的信息熵
# third step:计算最好的信息增益,返回对应的划分方式
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 # 待划分数据集最后一列为类别,其余列为特征feature
baseEntropy = calcShannonEnt(dataSet) # 计算未划分数据集时的香农熵
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures):
featList = [example[i] for example in dataSet] # dataSet里每一个example的第i个特征的值
# ⬆这一句展开来如下:
# for example in dataSet:
# featList = []
# featList.extend([example[i]])
# # featList.extend(example[i]) # 注意这么写是不对的,这么写featList无法被set调用
uniqueVals = set(featList) # set()可以去重,保证后面不会重复划分dataSet
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet) # 按照某个特征划分数据后的熵
infoGain = baseEntropy - newEntropy # 计算信息增益
if(infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature
# 测试结果
>>> myDat
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
>>> tree.chooseBestFeatureToSplit(myDat)
0 #即按照第0个特征划分
3.递归构建决策树
构建决策树的流程大致为:得到原始数据集——基于最好的特征属性值划分数据集。由于特征值可能多于两个,所以存在多次划分,即需要递归处理。
递归结束的条件有二:一是程序遍历完所有划分数据集的特征属性,二是每个分支下的所有示例都具有相同的分类(我理解的是两个条件同时满足,递归才终止)。但有时数据集已经处理了所有特征属性,类标签仍不唯一,此时就需要决定如何定义该叶子节点。如下采用多数表决的方法。
# 创建树函数
# fist:特征类别完全相同停止继续划分,返回特征类别
# second:遍历完所有特征,返回出现次数最多的特征类别
# third:创建字典树,存储当前最好特征属性值
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet] # classList表示类别标签
if classList.count(classList[0]) == len(classList): # 若所有类别标签相同(classList中第0个元素的个数==classList的长度)
return classList[0] # 则不用分类,直接返回类别。递归停止。
if len(dataSet[0]) == 1: # 若使用完了所有特征,仍不能将数据集划分为仅含一类的分组
return majorityCnt(classList) # 则调用majorityCnt函数投票。递归停止
bestFeat = chooseBestFeatureToSplit(dataSet) # 根据最大信息增益的原理选择最佳划分特征(特征索引,如第0个特征、第1个特征等)
bestFeatLabel = labels[bestFeat] # 最佳划分特征的标签
myTree = {bestFeatLabel:{}} # 构造字典树
subLabels = labels[:] # 复制labels给subLabels,防止破坏原始labels
# 这里的赋值源程序放在下面的循环里,没卵用啊...labels还是会变啊...所以放到了循环外
del(subLabels[bestFeat]) # 去除当前最佳特征的标签
featValues = [example[bestFeat] for example in dataSet] # 当前最佳特征对应的取值
uniqueVals = set(featValues)
for value in uniqueVals:
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
# 递归调用。先要调用splitDataSet函数,获得subDataSet。
return myTree
# 测试结果
>>> myDat,labels=tree.createDataSet()
>>> myTree=tree.createTree(myDat,labels)
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
python中,函数参数为列表类型
createTree(dataSet, labels)
函数中包含dataSet和labels两个参数,均为列表类型,且代码中重复调用labels参数,需要‘去除当前最佳特征的标签’,会改变原始labels的值。所以先将labels复制:subLabels = labels[:]
,再去除:del(subLabels[bestFeat])
。
list.count()
list.count(obj)
返回列表list中某个元素obj出现的次数。
>>> aList = [123, 'xyz', 'zara', 'abc', 123]
>>> aList.count(123)
2
>>> aList.count('zara')
1
二、绘制决策树(Matplotlib)
treePlotter.py
1.Matplotlib注解
import matplotlib.pyplot as plt
#解决中文显示问题
plt.rcParams['font.sans-serif'] = ['KaiTi'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False # 解决保存图像是负号'-'显示为方块的问题
# 定义决策节点、叶节点以及标记线
decisionNode = dict(boxstyle="sawtooth", fc="0.8") # 决策节点
leafNode = dict(boxstyle="round4", fc='0.8') # 叶节点
arrow_args = dict(arrowstyle="<-") # 标记线(起点为箭头尖端)
# arrow_args = dict(arrowstyle="->") # 标记线(起点为箭头线段端)——默认
dict()
class dict(**kwarg) # 传入关键字
class dict(mapping, **kwarg) # 利用某种mapping(元素容器)构造字典
class dict(iterable, **kwarg) # 利用可迭代对象构造字典
用于创建一个字典。
>>>dict() # 创建空字典
{}
>>> dict(a='a', b='b', t='t') # 传入关键字
{'a': 'a', 'b': 'b', 't': 't'}
>>> dict(zip(['one', 'two', 'three'], [1, 2, 3])) # 映射函数方式来构造字典
{'three': 3, 'two': 2, 'one': 1}
>>> dict([('one', 1), ('two', 2), ('three', 3)]) # 可迭代对象方式来构造字典
{'three': 3, 'two': 2, 'one': 1}
# 绘制节点,给定(节点文本内容,节点框中心坐标(标记线起点坐标),标记线终点坐标,绘制类型)
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
# 个人理解:createPlot.ax1相当于一个全局的变量,在createPlot()函数中创建,在plotNode()函数中调用
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
# node_txt: 节点文本内容
# xy: 标记线终点坐标(此例中标记线为箭头,终点为箭头线段端)
# xycoords:标记线的坐标原点位置,是以图像还是坐标轴
# xytext: 节点框的中心坐标
# textcoords:节点框的坐标原点位置,是以图像还是坐标轴
# va: vertical alignment 文本中的内容 横向对齐方式
# ha: horizontal alignment 文本中内容竖向对齐方式
# bbox: 节点框的设置
# arrowprops: 标记线的类型,是一个字典,arrowstyle未指定,则默认类别为'->'
createPlot.ax1.annotate():
参考:1&2.
# 绘制
def createPlot():
fig = plt.figure(1, facecolor='white')
fig.clf()
createPlot.ax1 = plt.subplot(111, frameon=False)
plotNode("决策节点", (0.5,0.1), (0.1,0.5), decisionNode)
plotNode("叶节点", (0.8,0.1), (0.3,0.8), leafNode)
plt.show()
# 测试结果
>>> import treePlotter
>>> treePlotter.createPlot()
2.构造注解树
为了绘制一颗决策树,要知道如何放置树的节点。所以定义下面两个函数,分别用来获取叶节点的个数(x轴长度)和树的层数(y轴高度):
# 获取叶节点数目,以便确定x轴长度
# 叶节点:得到类别的节点;决策节点:需要进一步判断的节点
# 思路:判断该节点对应值的数据类型是否为‘字典’,若是,该节点为决策节点,递归;若不是,该节点为叶节点
def getNumLeafs(myTree):
numLeafs = 0
# firstStr为第一个决策节点
firstStr = list(myTree.keys())[0] #py3.*中,dict.keys()的值类型不带下标,需转化成list
# firstStr为第一个决策节点对应的值,即余下的节点
secondDict = myTree[firstStr]
for key in secondDict.keys(): # 在余下的节点中遍历
if type(secondDict[key]).__name__=='dict': # 判断第二个节点是否仍为字典类型
numLeafs += getNumLeafs(secondDict[key]) # 若是,该节点为决策节点,递归调用getNumLeafs()
else:
numLeafs += 1 #若不是,该节点为叶节点
return numLeafs
# 获取树的层数(决策节点的个数),以便确定y轴长度
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth # 取遍历完后的最大值
return maxDepth
为节省时间,预存储两个树的信息:
# 预先存储树信息,避免每次测试都要创建树
def retrieveTree(i):
listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
return listOfTrees[i]
# 测试结果
>>> treePlotter.retrieveTree(1)
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
>>> myTree=treePlotter.retrieveTree(0)
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>> treePlotter.getNumLeafs(myTree)
3
>>> treePlotter.getTreeDepth(myTree)
2
3.绘制决策树
# 在父子节点间填充文本信息
# 给定(标记线起点坐标(即当前节点框中心坐标),标记线终点坐标(即上一个节点框中心坐标),文本信息)
# 返回父子节点的中间位置坐标
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString)
plt.text()
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree) # 获取子节点个数
depth = getTreeDepth(myTree) # 获取树的层数
firstStr = list(myTree.keys())[0]
# cntPt为叶节点中心坐标
cntrPt = (plotTree.xOff + (1.0+float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) # ???这里怎么计算
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #no ticks
#createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
plotTree.totalW = float(getNumLeafs(inTree)) #全局变量plotTree.totalW,储存树的宽度
plotTree.totalD = float(getTreeDepth(inTree)) #全局变量PlotTree.totalD,储存树的高度
plotTree.xOff = -0.5/plotTree.totalW #全局变量plotTree.xOff,追踪已绘制的节点位置,从左到右,即递增
plotTree.yOff = 1.0 #全局变量plotTree.yOff,追踪已绘制的节点位置,从上到下,即递减
plotTree(inTree, (0.5,1.0), '') # 在(0.5,1.0)的位置绘制第一个节点(父节点/根节点)
plt.show()
三、测试和存储分类器
使用决策树执行分类
# 使用决策树的分类函数
# first step: 使用index方法查找当前列表中第一个匹配firsStr的元素
# second step: 递归遍历整棵树,比较testVec中的值和树节点的值,若为决策节点(字典类型),递归;否则即为叶节点,返回该节点类别标签
def classify(inputTree, featLabels, testVec):
firstStr = list(inputTree.keys())[0] # 决策树第一个特征
secondDict = inputTree[firstStr] # 决策树第一个特征对应的值
featIndex = featLabels.index(firstStr) # 决策树第一个特征在featLabls中的索引(第0个特征/第1个特征等)
for key in secondDict.keys(): # 遍历余下的特征键值,找到和测试数据相等的key
if testVec[featIndex] == key: # 若testVec中第featIndex个值等于key
if type(secondDict[key]).__name__=='dict': # 若还是字典,递归
classLabel = classify(secondDict[key], featLabels, testVec)
else:
classLabel = secondDict[key] # 否则直接返回key对应的类别标签
return classLabel
# 测试结果
>>> myDat,labels=tree.createDataSet()
>>> myTree=tree.createTree(myDat,labels)
>>> myDat
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
>>> labels
['no surfacing', 'flippers']
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>> tree.classify(myTree,labels,[1,1])
'yes'
index()
list.index(obj)
返回某个对象obj在列表list中的索引值。
>>> list1 = ['Google', 'Runoob', 'Taobao']
>>> list1.index("Google")
0
>>> list1.index("Taobao")
2
2.使用算法:决策树的存储
构造决策树是很耗时的任务,但是用构造好的决策树解决分类问题可以很快完成。所以最好能在每次执行分类时调用已经构造好的决策树,这就需要将决策树存储在磁盘中。
# 使用pickle模块存储决策树
def storeTree(inputTree, filename):
import pickle
# fw = open(filename, 'w') #TypeError: write() argument must be str, not bytes
fw = open(filename, 'wb')
pickle.dump(inputTree, fw)
fw.close()
# 使用pickle模块读取决策树
def grabTree(filename):
import pickle
# fr = open(filename) # UnicodeDecodeError: 'gbk' codec can't decode byte 0x80 in position 0: illegal multibyte sequence
fr = open(filename, 'rb')
return pickle.load(fr)
# 测试结果
>>> tree.storeTree(myTree,'classifierStorage.txt')
>>> tree.grabTree('classifierStorage.txt')
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
pickle模块
用于序列化对象,可以将python中几乎所有的数据类型(列表,字典,集合,类等)序列化,将其作为文件储存。pickle序列化后的数据,可读性差,人一般无法识别。参考。
pickle.dump(obj, file[, protocol])
序列化对象,并将结果数据流写入到文件对象中。参数protocol是序列化模式,默认值为0,表示以文本的形式序列化。protocol的值还可以是1或2,表示以二进制的形式序列化。
pickle.load(file)
反序列化对象。将文件中的数据解析为一个Python对象。
3.示例:使用决策树预测隐形眼镜类型
import tree
import treePlotter
fr = open('lenses.txt') # 读取隐形眼镜数据集
lenses = [inst.strip().split('\t') for inst in fr.readlines()] # 将文本数据集转化为列表
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = tree.createTree(lenses, lensesLabels)
print(lensesTree)
treePlotter.createPlot(lensesTree)
# output:
{'tearRate': {'normal': {'astigmatic': {'no': {'age': {'young': 'soft', 'presbyopic': {'prescript': {'myope': 'no lenses', 'hyper': 'soft'}}, 'pre': 'soft'}}, 'yes': {'prescript': {'myope': 'hard', 'hyper': {'age': {'young': 'hard', 'presbyopic': 'no lenses', 'pre': 'no lenses'}}}}}}, 'reduced': 'no lenses'}}
总结
优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。
缺点:可能会产生过度匹配问题。例如隐形眼镜的示例中,匹配选项过多,可以考虑裁剪部分叶节点(如果该节点不能产生太多信息)。
适用数据类型:标称型和数值型(需要离散化)。