mirror of
https://github.com/apachecn/ailearning.git
synced 2026-02-10 05:45:40 +08:00
133 lines
4.4 KiB
Python
133 lines
4.4 KiB
Python
#!/usr/bin/python
|
||
# coding:utf8
|
||
|
||
'''
|
||
Created on Oct 14, 2010
|
||
Update on 2017-02-27
|
||
Decision Tree Source Code for Machine Learning in Action Ch. 3
|
||
@author: Peter Harrington/jiangzhonglian
|
||
'''
|
||
import matplotlib.pyplot as plt
|
||
|
||
# 定义文本框 和 箭头格式 【 sawtooth 波浪方框, round4 矩形方框 , fc表示字体颜色的深浅 0.1~0.9 依次变浅,没错是变浅】
|
||
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
|
||
leafNode = dict(boxstyle="round4", fc="0.8")
|
||
arrow_args = dict(arrowstyle="<-")
|
||
|
||
|
||
def getNumLeafs(myTree):
|
||
numLeafs = 0
|
||
firstStr = myTree.keys()[0]
|
||
secondDict = myTree[firstStr]
|
||
# 根节点开始遍历
|
||
for key in secondDict.keys():
|
||
# 判断子节点是否为dict, 不是+1
|
||
if type(secondDict[key]).__name__ == 'dict':
|
||
numLeafs += getNumLeafs(secondDict[key])
|
||
else:
|
||
numLeafs += 1
|
||
return numLeafs
|
||
|
||
|
||
def getTreeDepth(myTree):
|
||
maxDepth = 0
|
||
firstStr = myTree.keys()[0]
|
||
secondDict = myTree[firstStr]
|
||
# 根节点开始遍历
|
||
for key in secondDict.keys():
|
||
# 判断子节点是不是dict, 求分枝的深度
|
||
if type(secondDict[key]).__name__ == 'dict':
|
||
thisDepth = 1 + getTreeDepth(secondDict[key])
|
||
else:
|
||
thisDepth = 1
|
||
# 记录最大的分支深度
|
||
if thisDepth > maxDepth:
|
||
maxDepth = thisDepth
|
||
return maxDepth
|
||
|
||
|
||
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
|
||
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
|
||
|
||
|
||
def plotMidText(cntrPt, parentPt, txtString):
|
||
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
|
||
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
|
||
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
|
||
|
||
|
||
def plotTree(myTree, parentPt, nodeTxt):
|
||
# 获取叶子节点的数量
|
||
numLeafs = getNumLeafs(myTree)
|
||
# 获取树的深度
|
||
# depth = getTreeDepth(myTree)
|
||
|
||
# 找出第1个中心点的位置,然后与 parentPt定点进行划线
|
||
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
|
||
# print cntrPt
|
||
# 并打印输入对应的文字
|
||
plotMidText(cntrPt, parentPt, nodeTxt)
|
||
|
||
firstStr = myTree.keys()[0]
|
||
# 可视化Node分支点
|
||
plotNode(firstStr, cntrPt, parentPt, decisionNode)
|
||
# 根节点的值
|
||
secondDict = myTree[firstStr]
|
||
# y值 = 最高点-层数的高度[第二个节点位置]
|
||
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
|
||
for key in secondDict.keys():
|
||
# 判断该节点是否是Node节点
|
||
if type(secondDict[key]).__name__ == 'dict':
|
||
# 如果是就递归调用[recursion]
|
||
plotTree(secondDict[key], cntrPt, str(key))
|
||
else:
|
||
# 如果不是,就在原来节点一半的地方找到节点的坐标
|
||
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
|
||
# 可视化该节点位置
|
||
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
|
||
# 并打印输入对应的文字
|
||
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
|
||
# plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
|
||
|
||
|
||
def createPlot(inTree):
|
||
# 创建一个figure的模版
|
||
fig = plt.figure(1, facecolor='green')
|
||
fig.clf()
|
||
|
||
axprops = dict(xticks=[], yticks=[])
|
||
# 表示创建一个1行,1列的图,createPlot.ax1 为第 1 个子图,
|
||
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
|
||
|
||
plotTree.totalW = float(getNumLeafs(inTree))
|
||
plotTree.totalD = float(getTreeDepth(inTree))
|
||
# 半个节点的长度
|
||
plotTree.xOff = -0.5/plotTree.totalW
|
||
plotTree.yOff = 1.0
|
||
plotTree(inTree, (0.5, 1.0), '')
|
||
plt.show()
|
||
|
||
|
||
# # 测试画图
|
||
# def createPlot():
|
||
# fig = plt.figure(1, facecolor='white')
|
||
# fig.clf()
|
||
# # ticks for demo puropses
|
||
# createPlot.ax1 = plt.subplot(111, frameon=False)
|
||
# plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
|
||
# plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
|
||
# plt.show()
|
||
|
||
|
||
# 测试数据集
|
||
def retrieveTree(i):
|
||
listOfTrees = [
|
||
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
|
||
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
|
||
]
|
||
return listOfTrees[i]
|
||
|
||
|
||
myTree = retrieveTree(0)
|
||
createPlot(myTree)
|