mirror of
https://github.com/apachecn/ailearning.git
synced 2026-02-10 05:45:40 +08:00
决策树测试案例更新完成
This commit is contained in:
132
src/python/03.DecisionTree/DecisionTreePlot.py
Normal file
132
src/python/03.DecisionTree/DecisionTreePlot.py
Normal file
@@ -0,0 +1,132 @@
|
||||
#!/usr/bin/python
|
||||
# coding:utf8
|
||||
|
||||
'''
|
||||
Created on Oct 14, 2010
|
||||
Update on 2017-02-27
|
||||
Decision Tree Source Code for Machine Learning in Action Ch. 3
|
||||
@author: Peter Harrington/jiangzhonglian
|
||||
'''
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Text-box and arrow styles passed to matplotlib's annotate().
# boxstyle "sawtooth" gives a zig-zag edged box, "round4" a rounded box.
# fc is the box face (fill) colour as a grayscale string: values closer to
# "1.0" are lighter, values closer to "0.0" are darker.
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
leafNode = {"boxstyle": "round4", "fc": "0.8"}
arrow_args = {"arrowstyle": "<-"}
|
||||
|
||||
|
||||
def getNumLeafs(myTree):
    """Count the leaf nodes of a nested-dict decision tree.

    A tree is a single-key dict of the form
    ``{feature: {branch_value: subtree_or_label, ...}}``. A branch whose
    value is a dict is a subtree and is counted recursively; any other
    value is a leaf.

    Args:
        myTree: The nested-dict tree to inspect.

    Returns:
        The number of leaves (0 for an empty tree).
    """
    if not myTree:
        return 0
    # dict.keys() is a non-indexable view on Python 3, so take the first
    # (and only) key through the iterator instead of keys()[0].
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    numLeafs = 0
    # Walk every branch under the root feature.
    for child in secondDict.values():
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs
|
||||
|
||||
|
||||
def getTreeDepth(myTree):
    """Return the depth (number of decision levels) of a nested-dict tree.

    A tree is a single-key dict of the form
    ``{feature: {branch_value: subtree_or_label, ...}}``. A leaf branch
    contributes depth 1; a subtree branch contributes 1 plus its own depth.

    Args:
        myTree: The nested-dict tree to inspect.

    Returns:
        The length of the longest root-to-leaf path, counted in decision
        nodes (0 for an empty tree).
    """
    if not myTree:
        return 0
    # dict.keys() is a non-indexable view on Python 3, so take the first
    # (and only) key through the iterator instead of keys()[0].
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    maxDepth = 0
    # Walk every branch under the root feature and keep the deepest one.
    for child in secondDict.values():
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth
|
||||
|
||||
|
||||
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one boxed node label and link it to its parent with an arrow.

    Args:
        nodeTxt: Text shown inside the box.
        centerPt: (x, y) position of the text, in axes-fraction coordinates.
        parentPt: (x, y) position the arrow connects to, in axes-fraction
            coordinates.
        nodeType: Box style dict passed as ``bbox`` (e.g. decisionNode or
            leafNode).
    """
    # Both points are interpreted relative to the axes (0..1 in each direction).
    annotate_options = dict(
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )
    createPlot.ax1.annotate(nodeTxt, **annotate_options)
|
||||
|
||||
|
||||
def plotMidText(cntrPt, parentPt, txtString):
    """Write a branch label halfway between a node and its parent.

    Args:
        cntrPt: (x, y) position of the child node.
        parentPt: (x, y) position of the parent node.
        txtString: Label text, drawn centred and rotated by 30 degrees.
    """
    # Midpoint of the segment, computed per axis (0 = x, 1 = y).
    midpoint = [
        (parentPt[axis]-cntrPt[axis])/2.0 + cntrPt[axis]
        for axis in (0, 1)
    ]
    createPlot.ax1.text(midpoint[0], midpoint[1], txtString,
                        va="center", ha="center", rotation=30)
|
||||
|
||||
|
||||
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw a nested-dict decision tree below ``parentPt``.

    Layout state lives on function attributes that createPlot() initialises
    before the first call:
      * plotTree.totalW - total number of leaves in the whole tree
      * plotTree.totalD - depth of the whole tree
      * plotTree.xOff   - x position of the most recently drawn leaf
      * plotTree.yOff   - y position of the level currently being drawn

    Args:
        myTree: Single-key nested dict ``{feature: {value: subtree_or_label}}``.
        parentPt: (x, y) of the parent node, in axes-fraction coordinates.
        nodeTxt: Label written on the branch from the parent to this node.
    """
    # Number of leaves under this node determines how wide the subtree is.
    numLeafs = getNumLeafs(myTree)

    # Centre this decision node horizontally over its own leaves.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    # Label the branch coming down from the parent.
    plotMidText(cntrPt, parentPt, nodeTxt)

    # dict.keys() is a non-indexable view on Python 3, so take the first
    # (and only) key through the iterator instead of keys()[0].
    firstStr = next(iter(myTree))
    # Draw the decision node itself.
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]

    # Step down one level for this node's children.
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            # Subtree: recurse with this node as the parent.
            plotTree(child, cntrPt, str(key))
        else:
            # Leaf: advance one leaf slot to the right and draw it.
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Step back up so siblings drawn after this subtree stay on the
    # parent's child level; without this they would be drawn too low.
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
|
||||
|
||||
|
||||
def createPlot(inTree):
    """Plot a nested-dict decision tree in matplotlib figure 1 and show it.

    Initialises the layout state that plotTree() reads from its function
    attributes (totalW, totalD, xOff, yOff) and stores the axes on
    createPlot.ax1 for plotNode() / plotMidText() to draw on.

    Args:
        inTree: Single-key nested dict ``{feature: {value: subtree_or_label}}``.
    """
    # Reuse figure 1 and start from a clean canvas.
    figure = plt.figure(1, facecolor='green')
    figure.clf()

    # A single subplot (1 row, 1 column) without frame or tick marks.
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])

    # Overall tree size, used to scale node positions into the 0..1 range.
    leaf_count = float(getNumLeafs(inTree))
    plotTree.totalW = leaf_count
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a leaf slot left of the axes so the first leaf is centred
    # in its slot.
    plotTree.xOff = -0.5/leaf_count
    plotTree.yOff = 1.0

    # The root has no real parent; anchor it at the top centre with no label.
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
|
||||
|
||||
|
||||
# # 测试画图
|
||||
# def createPlot():
|
||||
# fig = plt.figure(1, facecolor='white')
|
||||
# fig.clf()
|
||||
# # ticks for demo purposes
|
||||
# createPlot.ax1 = plt.subplot(111, frameon=False)
|
||||
# plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
|
||||
# plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
|
||||
# plt.show()
|
||||
|
||||
|
||||
# 测试数据集
|
||||
def retrieveTree(i):
    """Return one of the hard-coded sample decision trees.

    Args:
        i: Index of the sample tree: 0 is the two-level tree, 1 is the
            three-level tree.

    Returns:
        A freshly built nested dict for the requested tree.
    """
    two_level_tree = {
        'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}
    }
    three_level_tree = {
        'no surfacing': {
            0: 'no',
            1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}},
        }
    }
    return [two_level_tree, three_level_tree][i]
|
||||
|
||||
|
||||
# Demo: build the first sample tree and plot it.
# NOTE(review): these statements run at import time, so importing this module
# also draws the plot (createPlot ends with plt.show()) - consider moving them
# under an `if __name__ == "__main__":` guard if the module is imported elsewhere.
myTree = retrieveTree(0)
createPlot(myTree)
|
||||
Reference in New Issue
Block a user