更新 3.决策树 md

This commit is contained in:
jiangzhonglian
2017-08-23 18:38:42 +08:00
parent 877fee7a9a
commit b0c57257ca
5 changed files with 125 additions and 48 deletions

View File

@@ -75,24 +75,24 @@ def calcShannonEnt(dataSet):
return shannonEnt
def splitDataSet(dataSet, axis, value):
"""splitDataSet(通过遍历dataSet数据集求出axis对应的colnum列的值为value的行)
就是依据axis列进行分类,如果axis列的数据等于 value的时候就要将 axis 划分到我们创建的新的数据集中
def splitDataSet(dataSet, index, value):
"""splitDataSet(通过遍历dataSet数据集求出index对应的colnum列的值为value的行)
就是依据index列进行分类,如果index列的数据等于 value的时候就要将 index 划分到我们创建的新的数据集中
Args:
dataSet 数据集 待划分的数据集
axis 表示每一行的axis列 划分数据集的特征
value 表示axis列对应的value值 需要返回的特征的值。
index 表示每一行的index列 划分数据集的特征
value 表示index列对应的value值 需要返回的特征的值。
Returns:
axis列为value的数据集【该数据集需要排除axis列】
index列为value的数据集【该数据集需要排除index列】
"""
retDataSet = []
for featVec in dataSet:
# axis列为value的数据集【该数据集需要排除axis列】
# 判断axis列的值是否为value
if featVec[axis] == value:
# chop out axis used for splitting
# [:axis]表示前axis行即若 axis 为2就是取 featVec 的前 axis
reducedFeatVec = featVec[:axis]
# index列为value的数据集【该数据集需要排除index列】
# 判断index列的值是否为value
if featVec[index] == value:
# chop out index used for splitting
# [:index]表示前index行即若 index 为2就是取 featVec 的前 index
reducedFeatVec = featVec[:index]
'''
请百度查询一下: extend和append的区别
list.append(object) 向列表中添加一个对象object
@@ -109,11 +109,11 @@ def splitDataSet(dataSet, axis, value):
结果:
[1, 2, 3]
[1, 2, 3, [4, 5, 6]]
[1, 2, 3, [4, 5, 6], 7, 8, 9
[1, 2, 3, [4, 5, 6], 7, 8, 9]
'''
reducedFeatVec.extend(featVec[axis+1:])
# [axis+1:]表示从跳过 axis 的 axis+1行取接下来的数据
# 收集结果值 axis列为value的行【该行需要排除axis列】
reducedFeatVec.extend(featVec[index+1:])
# [index+1:]表示从跳过 index 的 index+1行取接下来的数据
# 收集结果值 index列为value的行【该行需要排除index列】
retDataSet.append(reducedFeatVec)
return retDataSet
@@ -250,6 +250,7 @@ def grabTree(filename):
fr = open(filename)
return pickle.load(fr)
def fishTest():
# 1.创建数据和结果标签
myDat, labels = createDataSet()
@@ -269,10 +270,11 @@ def fishTest():
myTree = createTree(myDat, copy.deepcopy(labels))
print myTree
# [1, 1]表示要取的分支上的节点位置,对应的结果值
# print classify(myTree, labels, [1, 1])
print classify(myTree, labels, [1, 1])
# 画图可视化展现
# dtPlot.createPlot(myTree)
dtPlot.createPlot(myTree)
def ContactLensesTest():
"""
@@ -294,10 +296,9 @@ def ContactLensesTest():
lensesTree = createTree(lenses, lensesLabels)
print lensesTree
# 画图可视化展现
# dtPlot.createPlot(lensesTree)
dtPlot.createPlot(lensesTree)
if __name__ == "__main__":
fishTest()
# ContactLensesTest()
# fishTest()
ContactLensesTest()

View File

@@ -87,7 +87,7 @@ def plotTree(myTree, parentPt, nodeTxt):
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
# 并打印输入对应的文字
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
# plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
def createPlot(inTree):

Binary file not shown.