第三章主要讲述的内容是决策树的源码和利用Matplotlib绘制树形图
决策树的核心是使用递归的逐级分类:
def createBranch()
If so return 类标签;
Else
寻找划分数据集的最好特征(香农熵)
划分数据集
创建分支节点
for 每个划分的子集:
调用函数createBranch()并增加返回结果到分支节点中
return 分支节点
分类是基于香农熵的计算,香农熵表示信息增益的大小。代码和注释如下
def calcShannonEnt(dataset):
numEntries = len(dataset)
labelCounts = {}
#()代表tuple元祖数据类型,元祖是一种不可变序列
#[]代表list列表数据类型,列表是一种可变序列
#{}代表字典数据类型,键值对
for featVec in dataset:
currentLabel = featVec[-1]#取最后一个元素
if currentLabel not in labelCounts.keys():#keys是字典的所有键,这一部分是计数的,
labelCounts[currentLabel] = 0#这一步已经创建了一个字典里面的键,=后面是赋值
labelCounts[currentLabel] += 1
#循环执行后labelCounts{'yes':2,'no':3}
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob * log(prob,2)
return shannonEnt
分类数据集可任意给出,书本作为参考。给出时注意所有数据的分类信息和标签。
def createDataset():
dataSet = [[1,1,'yes'],
[1,1,'yes'],
[1,0,'no'],
[0,1,'no'],
[0,1,'no']]
labels = ['no surfacing','flippers']
return dataSet,labels
下面是先划分数据集的代码,根据你要划分的依据来划分已知的数据集。
def splitDataSet(dataSet,axis,value):#划分数据集,三个参数分别是待划分的数据集、划分数据集的特征、需要返回的特征的值
retDatSet = []#list类型
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])#extend是把括号里面的一个一个加进去
retDatSet.append(reducedFeatVec)#append是把括号里面的当成一个元素加进去
return retDatSet
#测试代码如下
# myData,labels = createDataset()
# test = splitDataSet(myData,0,1)
# print(test)
#控制台输出:[[1, 'yes'], [1, 'yes'], [0, 'no']],返回的是0号特征为1的所有数据的信息
由于数据集有多种特征,这就涉及到特征的选择,哪些特征最“可区分”(即信息熵最大)
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1#特征数量
baseEntropy = calcShannonEnt(dataSet)#返回的是信息熵的值,计算的是整个数据集的原始香农熵
bestInfoGain = 0.0;bestFeature = -1
for i in range(numFeatures):#遍历数据集中所有特征
featList = [example[i] for example in dataSet]#取dataset里面每个元素的第一个
uniqueVals = set(featList)#根据括号里面的东西创建一个无序不重复元素集
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet,i,value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if(infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature#总体来说是判断哪一个特征的信息熵更大,返回最大信息变化的坐标
#测试代码
# myData,labels = createDataset()
# test = chooseBestFeatureToSplit(myData)
# print(test)
#控制台输出:0,表示0号特征应当作为划分特征
上面介绍了决策树的划分原理,下面按照决策过程介绍构建决策树的代码。
首先要了解决策树的执行过程:根据之前对决策节点的划分,越走越深,直到遍历完叶子节点,最后将每个叶子节点的所有数据分类为相同的
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys():classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]#利用operator操作键值排序字典,并返回出现次数最多的分类名称
def createTree(dataSet,labels):
classList = [example[-1] for example in dataSet]#-1表示最后一个元素,这里用的是一个小循环,取出了所有数据的最后一个标签,即['yes','yes','no','no','no']
if classList.count(classList[0]) == len(classList):#count() 方法用于统计字符串里某个字符或子字符串出现的次数
return classList[0]#如果list里面都是一样的东西
if len(dataSet[0]) == 1:
return majorityCnt(classList)#如果dataset里面只有一个数据
bestFeat = chooseBestFeatureToSplit(dataSet)#返回的是信息熵最大特征的坐标
bestFeatLabel = labels[bestFeat]#最适合分类的label名
myTree = {bestFeatLabel:{}}#字典是一对:key, value的键值对
del(labels[bestFeat])#del用于list列表操作,删除一个或者连续几个元素,这里删除了分类的label
featValues = [example[bestFeat] for example in dataSet]#bestfeat是0,返回dataset里面第一个的值
uniqueVals = set(featValues)#创建一个无序不重复元素集,可进行关系测试,删除重复数据,还可以计算交集、差集、并集等
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)#递归调用createTree函数,去除了信息熵最大的特征,用其他的特征继续输入这个函数,进行二级分类
#用字典的嵌套,字典里面的每一个值是一个字典
return myTree
#最后达到的目的是创建了字典型的树
#测试代码
# myDat,labels = createDataset()
# myTree = createTree(myDat,labels)
# print(myTree)
#控制台输出:{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},是嵌套字典的形式
下一本书介绍绘制树形图的过程,后面会介绍。下面介绍使用决策树对输入数据进行分类。
在进行数据分类时,需要输入的参数是决策树,以及用于构建树的标签向量和用于分类的测试向量。然后,程序将测试数据与决策树上的值进行比较,递归执行该过程,直到进入叶子节点;最后将测试数据定义为叶子节点所属的类型。决策树分类函数代码如下
def classify(inputTree,featLabels,testVec):#输入的参数分别是,原有的决策树、决策标签和用于分类的数据
firstStr = list(inputTree.keys())[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)#index() 方法检测字符串中是否包含子字符串 str
for key in secondDict.keys():
if testVec[featIndex] == key:#testVec是list类型的
if type(secondDict[key]).__name__=='dict':
classLabel = classify(secondDict[key],featLabels,testVec)#递归调用
else:
classLabel = secondDict[key]
return classLabel#返回的是分类的信息,yes或者no,判断海洋生物是不是鱼类
#第一节点名为no surfacing,它有两个子节点:一个是名字为0的叶子节点,类标签为no;另一个是名为flippers的判断节点,此处进入递归调用,flippers节点有两个子节点。
#测试代码
# myDat,labels = createDataset()
# myTree = treePlotter.retriveTree(0)
# test = classify(myTree,labels,[1,0])
# test1 = classify(myTree,labels,[1,1])
# print(test,test1)
# 控制台输出:no yes
这里有一个需要注意的点,list(inputTree.keys())[0]和书中的inputTree.keys()[0]不同,因为书中是py2,.keys返回的是list类型的数据,但是py3中返回的就是dict_keys类型,需要进行数据转换。
介绍完数据处理,接下来就是数据存储了
def storeTree(inputTree,filename):
import pickle
fw = open(filename,'wb')
pickle.dump(inputTree,fw)
fw.close()
def grabTree(filename):
import pickle
fr = open(filename)
return pickle.load(fr)
#测试代码
storeTree(myTree,'classifierStorage.txt')
grabTree('classifierStorage.txt')
store代码可以正常运行,但是生成的txt文件是乱码,目前还没找到原因
文章出处登录后可见!
已经登录?立即刷新