From a4bcdf74c3c5824c405b3c1114500dfc548bfce1 Mon Sep 17 00:00:00 2001
From: jiangzhonglian
Date: Tue, 28 Feb 2017 19:05:27 +0800
Subject: [PATCH] Decision tree test cases updated
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 docs/3.决策树.md                         |   6 +-
 .../DecisionTree.py}                     | 119 ++++++++++++----
 .../03.DecisionTree/DecisionTreePlot.py  | 132 ++++++++++++++++++
 3 files changed, 231 insertions(+), 26 deletions(-)
 rename src/python/{03.DecisionTree.py => 03.DecisionTree/DecisionTree.py} (54%)
 create mode 100644 src/python/03.DecisionTree/DecisionTreePlot.py

diff --git a/docs/3.决策树.md b/docs/3.决策树.md
index f91b40bf..4a2db238 100644
--- a/docs/3.决策树.md
+++ b/docs/3.决策树.md
@@ -1,5 +1,6 @@
 # 3) Decision Tree
 
+
 * What is a decision tree?
     * As the name suggests, it is a tree built on a series of strategic choices.
@@ -17,7 +18,8 @@
 * The guiding principle when splitting a data set: make disordered data more ordered.
 * The measure of information in a set is called `Shannon entropy`, or simply `entropy` (named after `Claude Shannon`, the father of information theory).
 * Formulas:
-    * l(x_i) = -log_2 P(x_i)
-    * 
+    * \\(p(x_i)\\) is the probability of choosing class label \\(x_i\\)
+    * \\(l(x_i) = -\log_2 p(x_i)\\) is the information of the symbol \\(x_i\\)
+    * \\(H = -\sum_{i=1}^{n} p(x_i)\log_2 p(x_i)\\) is the Shannon entropy, i.e. the expected information over all class labels
 * Gini impurity [not covered further in this book]
     * In short: randomly pick an item from the data set and measure the probability that it gets misclassified into another group.
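
To make the entropy formulas above concrete, here is a minimal, self-contained
sketch (not part of the patch; the helper name shannon_entropy is invented for
this note). It reproduces the value that calcShannonEnt below computes for the
sample data's two 'yes' / three 'no' labels:

    from collections import Counter
    from math import log

    def shannon_entropy(labels):
        # H = -sum over classes of p(x_i) * log2(p(x_i))
        n = float(len(labels))
        return -sum((c / n) * log(c / n, 2) for c in Counter(labels).values())

    print(shannon_entropy(['yes', 'yes', 'no', 'no', 'no']))  # ~0.9710
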
diff --git a/src/python/03.DecisionTree.py b/src/python/03.DecisionTree/DecisionTree.py
similarity index 54%
rename from src/python/03.DecisionTree.py
rename to src/python/03.DecisionTree/DecisionTree.py
index e958eef8..c7920bd1 100644
--- a/src/python/03.DecisionTree.py
+++ b/src/python/03.DecisionTree/DecisionTree.py
@@ -9,6 +9,7 @@ Decision Tree Source Code for Machine Learning in Action Ch. 3
 '''
 from math import log
 import operator
+import DecisionTreePlot as dtPlot
 
 
 def createDataSet():
@@ -26,13 +27,18 @@ def createDataSet():
         [1, 0, 'no'],
         [0, 1, 'no'],
         [0, 1, 'no']]
+    # dataSet = [['yes'],
+    #            ['yes'],
+    #            ['no'],
+    #            ['no'],
+    #            ['no']]
     labels = ['no surfacing', 'flippers']
     # change to discrete values
     return dataSet, labels
 
 
 def calcShannonEnt(dataSet):
-    """calcShannonEnt(calculate Shannon entropy)
+    """calcShannonEnt(calculate the Shannon entropy of the class labels)
 
     Args:
         dataSet  the data set
@@ -61,83 +67,136 @@ def calcShannonEnt(dataSet):
         prob = float(labelCounts[key])/numEntries
         # log base 2
         shannonEnt -= prob * log(prob, 2)
-        print '---', prob, prob * log(prob, 2), shannonEnt
+        # print '---', prob, prob * log(prob, 2), shannonEnt
     return shannonEnt
 
 
 def splitDataSet(dataSet, axis, value):
+    """splitDataSet(return the rows of dataSet whose column `axis` equals `value`)
+
+    Args:
+        dataSet  the data set
+        axis     the column index tested in each row
+        value    the value that column `axis` must equal
+    Returns:
+        the matching rows, each with column `axis` removed
+    Raises:
+
+    """
     retDataSet = []
     for featVec in dataSet:
+        # keep only the rows whose column `axis` equals `value`
         if featVec[axis] == value:
             # chop out axis used for splitting
             reducedFeatVec = featVec[:axis]
+            '''
            note the difference between extend and append: extend splices the
            elements in flat; append would nest the whole list as one element
+            '''
            reducedFeatVec.extend(featVec[axis+1:])
+            # collect the matching row, minus column `axis`
             retDataSet.append(reducedFeatVec)
     return retDataSet
 
 
 def chooseBestFeatureToSplit(dataSet):
-    # the last column is used for the labels
+    """chooseBestFeatureToSplit(choose the best feature to split on)
+
+    Args:
+        dataSet  the data set
+    Returns:
+        bestFeature  the index of the feature column with the highest information gain
+    Raises:
+
+    """
+    # number of feature columns in the first row (the last column holds the labels)
     numFeatures = len(dataSet[0]) - 1
+    # entropy of the class labels
     baseEntropy = calcShannonEnt(dataSet)
-    bestInfoGain = 0.0
-    bestFeature = -1
+    # best information gain so far, and the index of the best feature
+    bestInfoGain, bestFeature = 0.0, -1
     # iterate over all the features
     for i in range(numFeatures):
         # create a list of all the examples of this feature
+        # collect all values of feature i into one list
         featList = [example[i] for example in dataSet]
         # get a set of unique values
-        uniqueVals = set(featList)
+        # deduplicate the values
+        uniqueVals = set(featList)
+        # entropy accumulator for splitting on this feature
         newEntropy = 0.0
+        # for each value of this column, add the weighted entropy of its subset
         for value in uniqueVals:
             subDataSet = splitDataSet(dataSet, i, value)
             prob = len(subDataSet)/float(len(dataSet))
-            newEntropy += prob * calcShannonEnt(subDataSet)
-        infoGain = baseEntropy - newEntropy     #calculate the info gain; ie reduction in entropy
-        if (infoGain > bestInfoGain):       #compare this to the best gain so far
-            bestInfoGain = infoGain         #if better than current best, set to best
+            newEntropy += prob * calcShannonEnt(subDataSet)
+        # information gain = label entropy minus the entropy after the split;
+        # the larger the reduction, the better this feature separates the classes
+        infoGain = baseEntropy - newEntropy
+        if (infoGain > bestInfoGain):
+            bestInfoGain = infoGain
             bestFeature = i
-    return bestFeature      #returns an integer
+    return bestFeature
 
 
 def majorityCnt(classList):
+    """majorityCnt(pick the class label that occurs most often)
+
+    Args:
+        classList  the list of class labels
+    Returns:
+        the majority label (e.g. yes/no)
+    Raises:
+
+    """
     classCount = {}
     for vote in classList:
         if vote not in classCount.keys():
             classCount[vote] = 0
         classCount[vote] += 1
+    # sort classCount by count, descending; the first entry is the majority label
     sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
+    # print 'sortedClassCount:', sortedClassCount
     return sortedClassCount[0][0]
 
 
 def createTree(dataSet, labels):
     classList = [example[-1] for example in dataSet]
+    # if the first label's count equals the size of the whole list, there is
+    # only one class, so return that label directly
     if classList.count(classList[0]) == len(classList):
-        return classList[0]#stop splitting when all of the classes are equal
-    if len(dataSet[0]) == 1: #stop splitting when there are no more features in dataSet
+        return classList[0]
+    # if only the label column is left, return the most frequent label
+    if len(dataSet[0]) == 1:
         return majorityCnt(classList)
+
+    # choose the best feature column and look up the feature name it corresponds to
     bestFeat = chooseBestFeatureToSplit(dataSet)
     bestFeatLabel = labels[bestFeat]
-    myTree = {bestFeatLabel:{}}
+    # initialize myTree
+    myTree = {bestFeatLabel: {}}
+    # note: labels is a mutable list, passed by reference in Python, so this del
+    # also removes the element from the caller's list; re-running the examples
+    # then fails with "'no surfacing' is not in list"
     del(labels[bestFeat])
+    # take the values of the best column and grow one branch per unique value
     featValues = [example[bestFeat] for example in dataSet]
     uniqueVals = set(featValues)
     for value in uniqueVals:
-        subLabels = labels[:]       #copy all of labels, so trees don't mess up existing labels
-        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
+        # copy the remaining labels so recursion does not clobber the caller's list
+        subLabels = labels[:]
+        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
+        # print 'myTree', value, myTree
     return myTree
 
 
 def classify(inputTree, featLabels, testVec):
-    # the first node of the tree
-    print '1111', inputTree.keys()
+    # the key of the root node
     firstStr = inputTree.keys()[0]
+    # the value (subtree or leaf) under the root node
     secondDict = inputTree[firstStr]
+    # look up the root node's feature index, then use testVec to pick the branch
     featIndex = featLabels.index(firstStr)
     key = testVec[featIndex]
     valueOfFeat = secondDict[key]
+    # print '+++', firstStr, 'xxx', secondDict, '---', key, '>>>', valueOfFeat
+    # recurse while the branch is still a dict; a plain value is the class label
     if isinstance(valueOfFeat, dict):
         classLabel = classify(valueOfFeat, featLabels, testVec)
     else:
@@ -145,7 +204,7 @@ def classify(inputTree, featLabels, testVec):
     return classLabel
 
 
-def storeTree(inputTree,filename):
+def storeTree(inputTree, filename):
     import pickle
     fw = open(filename, 'w')
     pickle.dump(inputTree, fw)
@@ -162,11 +221,23 @@ if __name__ == "__main__":
 
     # 1. create the sample data and the class labels
     myDat, labels = createDataSet()
-    print myDat, labels
+    # print myDat, labels
 
-    calcShannonEnt(myDat)
+    # # Shannon entropy of the class labels
+    # calcShannonEnt(myDat)
 
-    # import copy
-    # myTree = createTree(myDat, copy.deepcopy(labels))
-    # print myTree
+    # # rows where column 0 equals 1 / 0 (with column 0 removed)
+    # print '1---', splitDataSet(myDat, 0, 1)
+    # print '0---', splitDataSet(myDat, 0, 0)
+
+    # # index of the column with the best information gain
+    # print chooseBestFeatureToSplit(myDat)
+
+    import copy
+    myTree = createTree(myDat, copy.deepcopy(labels))
+    print myTree
+    # [1, 1] is the sequence of branch values to follow; prints the predicted label
+    # print classify(myTree, labels, [1, 1])
+
+    # visualize the tree
+    dtPlot.createPlot(myTree)
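
For reference, the nested-dict tree that createTree builds for the sample data,
and how classify walks it. The function walk below is a condensed
re-implementation sketch for this note, not the patch's code:

    tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    labels = ['no surfacing', 'flippers']

    def walk(tree, labels, vec):
        # each internal key is a feature name; each branch key is that
        # feature's value taken from the test vector
        root = list(tree.keys())[0]
        branch = tree[root][vec[labels.index(root)]]
        return walk(branch, labels, vec) if isinstance(branch, dict) else branch

    print(walk(tree, labels, [1, 1]))  # 'yes'
    print(walk(tree, labels, [1, 0]))  # 'no'
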
diff --git a/src/python/03.DecisionTree/DecisionTreePlot.py b/src/python/03.DecisionTree/DecisionTreePlot.py
new file mode 100644
index 00000000..cf11e382
--- /dev/null
+++ b/src/python/03.DecisionTree/DecisionTreePlot.py
@@ -0,0 +1,132 @@
+#!/usr/bin/python
+# coding:utf8
+
+'''
+Created on Oct 14, 2010
+Update on 2017-02-27
+Decision Tree Source Code for Machine Learning in Action Ch. 3
+@author: Peter Harrington/jiangzhonglian
+'''
+import matplotlib.pyplot as plt
+
+# box and arrow styles: sawtooth draws a wavy box, round4 a rounded box;
+# fc is the gray fill level of the box face, from 0.1 (dark) to 0.9 (light)
+decisionNode = dict(boxstyle="sawtooth", fc="0.8")
+leafNode = dict(boxstyle="round4", fc="0.8")
+arrow_args = dict(arrowstyle="<-")
+
+
+def getNumLeafs(myTree):
+    numLeafs = 0
+    firstStr = myTree.keys()[0]
+    secondDict = myTree[firstStr]
+    # walk the branches under the root node
+    for key in secondDict.keys():
+        # a dict child is a subtree: recurse; anything else is a leaf, count 1
+        if type(secondDict[key]).__name__ == 'dict':
+            numLeafs += getNumLeafs(secondDict[key])
+        else:
+            numLeafs += 1
+    return numLeafs
+
+
+def getTreeDepth(myTree):
+    maxDepth = 0
+    firstStr = myTree.keys()[0]
+    secondDict = myTree[firstStr]
+    # walk the branches under the root node
+    for key in secondDict.keys():
+        # a dict child is a subtree: its branch depth is 1 + the subtree's depth
+        if type(secondDict[key]).__name__ == 'dict':
+            thisDepth = 1 + getTreeDepth(secondDict[key])
+        else:
+            thisDepth = 1
+        # keep the largest branch depth
+        if thisDepth > maxDepth:
+            maxDepth = thisDepth
+    return maxDepth
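+
+# Worked example: for the test tree retrieveTree(0) at the bottom of this file,
+#   {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
+# the leaves are 'no', 'no', 'yes', so getNumLeafs returns 3, and the longest
+# root-to-leaf path passes the two decision nodes 'no surfacing' and 'flippers',
+# so getTreeDepth returns 2; plotTree below uses these two numbers as the total
+# width and height of the drawing.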
+
+
+def plotNode(nodeTxt, centerPt, parentPt, nodeType):
+    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
+
+
+def plotMidText(cntrPt, parentPt, txtString):
+    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
+    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
+    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
+
+
+def plotTree(myTree, parentPt, nodeTxt):
+    # number of leaf nodes
+    numLeafs = getNumLeafs(myTree)
+    # depth of the tree (not needed here)
+    # depth = getTreeDepth(myTree)
+
+    # locate the center point of this subtree, then draw a line to parentPt
+    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
+    # print cntrPt
+    # and label that connecting line
+    plotMidText(cntrPt, parentPt, nodeTxt)
+
+    firstStr = myTree.keys()[0]
+    # draw the decision node
+    plotNode(firstStr, cntrPt, parentPt, decisionNode)
+    # the subtree under the root node
+    secondDict = myTree[firstStr]
+    # y position = current top minus one level's height (where the next level is drawn)
+    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
+    for key in secondDict.keys():
+        # check whether this child is an internal node
+        if type(secondDict[key]).__name__ == 'dict':
+            # if so, recurse
+            plotTree(secondDict[key], cntrPt, str(key))
+        else:
+            # otherwise advance x by one leaf slot and place the leaf node there
+            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
+            # draw the leaf node
+            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
+            # and label the connecting line
+            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
+    # plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
+
+
+def createPlot(inTree):
+    # create the figure
+    fig = plt.figure(1, facecolor='green')
+    fig.clf()
+
+    axprops = dict(xticks=[], yticks=[])
+    # one chart in a 1x1 grid; createPlot.ax1 is that single subplot
+    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
+
+    plotTree.totalW = float(getNumLeafs(inTree))
+    plotTree.totalD = float(getTreeDepth(inTree))
+    # start half a leaf slot to the left of the first leaf
+    plotTree.xOff = -0.5/plotTree.totalW
+    plotTree.yOff = 1.0
+    plotTree(inTree, (0.5, 1.0), '')
+    plt.show()
+
+
+# # plotting smoke test
+# def createPlot():
+#     fig = plt.figure(1, facecolor='white')
+#     fig.clf()
+#     # ticks for demo purposes
+#     createPlot.ax1 = plt.subplot(111, frameon=False)
+#     plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
+#     plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
+#     plt.show()
+
+
+# canned test trees
+def retrieveTree(i):
+    listOfTrees = [
+        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
+        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
+    ]
+    return listOfTrees[i]
+
+
+# guard the demo so that `import DecisionTreePlot` stays side-effect free
+if __name__ == '__main__':
+    myTree = retrieveTree(0)
+    createPlot(myTree)
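
Taken together, the intended end-to-end flow after this patch (a usage sketch,
assuming a Python 2 environment with matplotlib installed, run from
src/python/03.DecisionTree/):

    import copy
    import DecisionTree as dt
    import DecisionTreePlot as dtPlot

    myDat, labels = dt.createDataSet()
    # deepcopy: createTree deletes entries from the labels list it is given
    myTree = dt.createTree(myDat, copy.deepcopy(labels))
    print(myTree)
    # {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

    # opens a matplotlib window with the rendered tree
    dtPlot.createPlot(myTree)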