#!/usr/bin/python
# coding:utf8
'''
Created on Oct 14, 2010
Update on 2017-02-27
Decision Tree Source Code for Machine Learning in Action Ch. 3
@author: Peter Harrington/jiangzhonglian

Draw a nested-dict decision tree with matplotlib annotations.

A tree is a dict whose (first) key is the label of the splitting feature and
whose value is a dict mapping each branch value either to a subtree (another
such dict) or to a leaf label.
'''
import matplotlib.pyplot as plt

# Box and arrow styles passed to matplotlib's ``annotate``.
# boxstyle: "sawtooth" draws a jagged-edged box, "round4" a rounded box.
# fc is the box face colour given as a grayscale string: "0.0" is black and
# "1.0" is white, so larger values give a lighter fill.
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
# "<-" puts the arrowhead at the text end of the annotation (the child node),
# so arrows point from parent to child.
arrow_args = dict(arrowstyle="<-")


def _splitRoot(myTree):
    """Return ``(rootLabel, children)`` for a tree dict.

    Only the first key of *myTree* is used. ``list(myTree)[0]`` is used
    instead of ``myTree.keys()[0]`` because dict key views are not indexable
    on Python 3. An empty dict raises IndexError.
    """
    rootLabel = list(myTree)[0]
    return rootLabel, myTree[rootLabel]


def getNumLeafs(myTree):
    """Return the number of leaf nodes in *myTree*.

    Each branch whose value is not a dict counts as one leaf; dict-valued
    branches are counted recursively.
    """
    numLeafs = 0
    _, secondDict = _splitRoot(myTree)
    for child in secondDict.values():
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """Return the number of decision levels in *myTree*.

    A tree whose branches are all leaves has depth 1; every nested subtree
    adds one level along its branch. A root with no branches has depth 0.
    """
    maxDepth = 0
    _, secondDict = _splitRoot(myTree)
    for child in secondDict.values():
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        # Keep the deepest branch seen so far.
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw a boxed node at *centerPt* with an arrow from *parentPt*.

    Both points are axes-fraction coordinates on ``createPlot.ax1``;
    *nodeType* is the bbox style dict (decisionNode or leafNode).
    """
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)


def plotMidText(cntrPt, parentPt, txtString):
    """Write *txtString* at the midpoint between *cntrPt* and *parentPt*."""
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString,
                        va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw *myTree*, connecting its root to *parentPt*.

    Layout state lives on function attributes initialised by ``createPlot``:
    ``plotTree.totalW`` (leaf count of the whole tree), ``plotTree.totalD``
    (depth of the whole tree), ``plotTree.xOff`` (x of the most recently
    drawn leaf) and ``plotTree.yOff`` (y of the current level). All
    coordinates are axes fractions.

    :param myTree: (sub)tree dict to draw.
    :param parentPt: (x, y) of the parent node the incoming arrow starts at.
    :param nodeTxt: branch label written midway along the incoming arrow.
    """
    numLeafs = getNumLeafs(myTree)
    # Centre this node over the numLeafs leaf slots that follow the last
    # leaf drawn so far.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    firstStr, secondDict = _splitRoot(myTree)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    # Children are drawn one level lower.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            plotTree(child, cntrPt, str(key))
        else:
            # Advance the x cursor by one leaf slot and draw the leaf there.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Move back up a level before returning so that siblings drawn after
    # this subtree stay at their own level instead of inheriting ours.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    """Open figure 1, draw *inTree* across the whole axes, and show it.

    Stores the axes on ``createPlot.ax1`` (used by plotNode/plotMidText)
    and initialises the ``plotTree`` layout attributes before drawing.
    """
    fig = plt.figure(1, facecolor='green')
    fig.clf()
    # No ticks: the axes is used purely as a drawing canvas.
    axprops = dict(xticks=[], yticks=[])
    # A single 1x1 subplot without a frame.
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a leaf slot left of 0 so the first leaf lands at the centre
    # of its slot (0.5 / totalW) and the root is centred at x = 0.5.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    # The root's parent point is its own position, so its arrow is empty.
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


def retrieveTree(i):
    """Return sample tree number *i* (0 or 1) for testing the plotter."""
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
    ]
    return listOfTrees[i]


if __name__ == '__main__':
    # Demo only when run as a script, so importing this module to reuse
    # createPlot does not pop up an extra window.
    myTree = retrieveTree(0)
    createPlot(myTree)